use crate::dtype::RawDataType;
use crate::ndarray::flags::NdArrayFlags;
use crate::util::functions::pad;
use crate::{NdArray, Reshape};
impl<'a, T: RawDataType> NdArray<'a, T> {
    /// Returns this array broadcast to `shape`, with the `Writeable` flag removed.
    ///
    /// The target shape and strides are computed by `broadcast_shape` and
    /// `broadcast_stride`; the result is built with `reshaped_view`.
    ///
    /// # Panics
    ///
    /// Panics if `shape` has fewer dimensions than this array, or if an axis of
    /// this array has a length that is neither 1 nor equal to the matching
    /// trailing axis of `shape`.
    pub fn broadcast_to(&'a self, shape: &[usize]) -> NdArray<'a, T> {
        let target_shape = broadcast_shape(&self.shape, shape);
        let target_stride = broadcast_stride(&self.stride, &target_shape, &self.shape);

        // SAFETY: NOTE(review) — the contract of `reshaped_view` is not visible here.
        // Presumably it needs shape/stride to stay within the existing buffer:
        // broadcast axes get stride 0 and every other axis keeps its original
        // stride and length. TODO confirm against `reshaped_view`'s `# Safety` docs.
        let mut view = unsafe { self.reshaped_view(target_shape, target_stride) };

        // Presumably read-only because stride-0 axes make several indices refer to
        // the same element — confirm.
        view.flags -= NdArrayFlags::Writeable;
        view
    }
}
/// Pads `shape` with 1s and `stride` with 0s via `pad`.
///
/// `ndims` must be at least `shape.len()`; otherwise the subtraction below
/// underflows.
///
/// NOTE(review): the last argument handed to `pad` is the number of missing
/// axes (`ndims - shape.len()`), not `ndims` itself. Confirm that `pad` expects
/// a count here rather than a target length.
fn pad_dimensions(shape: &[usize], stride: &[usize], ndims: usize) -> (Vec<usize>, Vec<usize>) {
    let missing_axes = ndims - shape.len();
    (pad(shape, 1, missing_axes), pad(stride, 0, missing_axes))
}
/// Validates that `shape` can be broadcast to `to` and returns `to` as an owned vector.
///
/// `shape` is aligned with the trailing axes of `to`. Each axis of `shape` must
/// have length 1 or match the corresponding axis of `to`.
///
/// # Panics
///
/// Panics if `to` has fewer dimensions than `shape`, or if any aligned axis is
/// incompatible.
fn broadcast_shape(shape: &[usize], to: &[usize]) -> Vec<usize> {
    if shape.len() > to.len() {
        panic!("cannot broadcast {shape:?} to shape {to:?} with fewer dimensions");
    }

    // Only the trailing `shape.len()` axes of the target take part in the check.
    let trailing = &to[to.len() - shape.len()..];
    let compatible = shape
        .iter()
        .zip(trailing)
        .all(|(&from, &target)| from == 1 || from == target);

    if !compatible {
        panic!("broadcasting {shape:?} is not compatible with the desired shape {to:?}");
    }

    to.to_vec()
}
/// Computes the strides of `original_shape` (with strides `stride`) when viewed
/// as `broadcast_shape`.
///
/// `original_shape` is aligned with the trailing axes of `broadcast_shape`.
/// New leading axes and axes whose original length is 1 get stride 0; axes whose
/// length matches the target keep their entry from `stride`.
///
/// # Panics
///
/// Panics if `broadcast_shape` has fewer dimensions than `original_shape`, or if
/// an original axis has a length that is neither 1 nor equal to the target
/// length. Also panics (index out of bounds) if `stride` has no entry for a
/// kept axis.
pub(crate) fn broadcast_stride(stride: &[usize],
                               broadcast_shape: &[usize],
                               original_shape: &[usize]) -> Vec<usize> {
    let target_ndims = broadcast_shape.len();
    if original_shape.len() > target_ndims {
        panic!("cannot broadcast {original_shape:?} to shape {broadcast_shape:?} with fewer dimensions");
    }

    // Axes that exist only in the target are pure repeats: stride 0.
    let leading = target_ndims - original_shape.len();
    let mut strides = vec![0; leading];

    let aligned_targets = &broadcast_shape[leading..];
    strides.extend(
        original_shape
            .iter()
            .zip(aligned_targets)
            .enumerate()
            .map(|(axis, (&from_len, &to_len))| match from_len {
                // A length-1 axis is repeated along the target axis.
                1 => 0,
                len if len == to_len => stride[axis],
                _ => panic!("broadcasting {original_shape:?} is not compatible with the desired shape {broadcast_shape:?}"),
            }),
    );

    strides
}
/// Computes the common shape that `first` and `second` broadcast to.
///
/// The shorter shape is aligned with the trailing axes of the longer one, as if
/// it were padded with leading 1s. For each aligned pair of axes, a length of 1
/// is stretched to the other length; otherwise the two lengths must be equal.
///
/// # Panics
///
/// Panics if an aligned pair of axes has different lengths and neither is 1.
pub(crate) fn broadcast_shapes(first: &[usize], second: &[usize]) -> Vec<usize> {
    // Ties go to `second` as the "longer" operand; the result is the same either way.
    let (longer, shorter) = if first.len() > second.len() {
        (first, second)
    } else {
        (second, first)
    };
    let leading = longer.len() - shorter.len();

    let mut result = Vec::with_capacity(longer.len());

    // Axes present only in the longer shape pair with an implicit 1 and pass through.
    result.extend_from_slice(&longer[..leading]);

    for (&short_dim, &long_dim) in shorter.iter().zip(&longer[leading..]) {
        let dim = if short_dim == 1 {
            long_dim
        } else if long_dim == 1 || long_dim == short_dim {
            short_dim
        } else {
            panic!("broadcasting {first:?} is not compatible with the desired shape {second:?}");
        };
        result.push(dim);
    }

    result
}
/// Lists the axes of `broadcast_shape` along which `original_shape` is repeated.
///
/// `original_shape` is aligned with the trailing axes of `broadcast_shape`. The
/// result holds, in ascending order:
/// - every leading axis that has no counterpart in `original_shape`, and
/// - every aligned axis whose original length is 1 and whose broadcast length
///   is greater than 1.
///
/// Lengths are not otherwise checked for compatibility.
///
/// # Panics
///
/// Panics if `broadcast_shape` has fewer dimensions than `original_shape`.
pub(crate) fn get_broadcasted_axes(broadcast_shape: &[usize],
                                   original_shape: &[usize]) -> Vec<isize> {
    if original_shape.len() > broadcast_shape.len() {
        panic!("cannot broadcast {original_shape:?} to shape {broadcast_shape:?} with fewer dimensions");
    }

    let leading = broadcast_shape.len() - original_shape.len();

    // Axes that exist only in the broadcast shape are always reported.
    let new_axes = (0..leading).map(|axis| axis as isize);

    // Aligned axes are reported only when a length-1 axis is actually stretched.
    let stretched_axes = original_shape
        .iter()
        .zip(&broadcast_shape[leading..])
        .enumerate()
        .filter(|&(_, (&from_dim, &to_dim))| from_dim == 1 && to_dim > 1)
        .map(|(offset, _)| (leading + offset) as isize);

    new_axes.chain(stretched_axes).collect()
}
#[cfg(test)]
mod tests {
    use crate::broadcast::{broadcast_shapes, get_broadcasted_axes};

    /// Asserts that broadcasting `original_shape` to `broadcast_shape` reports `expected` axes.
    fn assert_broadcasted_axes(broadcast_shape: &[usize], original_shape: &[usize], expected: &[isize]) {
        assert_eq!(
            get_broadcasted_axes(broadcast_shape, original_shape),
            expected,
            "broadcast_shape={broadcast_shape:?}, original_shape={original_shape:?}"
        );
    }

    /// A rank-2 and a rank-3 shape combine axis by axis from the trailing end.
    #[test]
    fn test_broadcast_shapes() {
        let output = broadcast_shapes(&[5, 1], &[2, 1, 3]);
        assert_eq!(output, [2, 5, 3]);
    }

    /// Covers stretched length-1 axes, new leading axes, and both together.
    #[test]
    fn test_get_broadcasted_axes() {
        assert_broadcasted_axes(&[3, 3], &[3, 1], &[1]);
        assert_broadcasted_axes(&[2, 3], &[3], &[0]);
        assert_broadcasted_axes(&[8, 7, 6], &[7, 1], &[0, 2]);
        assert_broadcasted_axes(&[4, 5, 6], &[1, 5, 1], &[0, 2]);
        assert_broadcasted_axes(&[5, 6], &[1, 6], &[0]);
        assert_broadcasted_axes(&[5, 6], &[5, 1], &[1]);
    }
}