use alloc::vec;
use alloc::vec::Vec;
use burn_std::Shape;
use crate::{FlexTensor, Layout};
/// Computes the shape produced by broadcasting `lhs` against `rhs`
/// using NumPy-style rules: dimensions are aligned from the trailing
/// end, missing leading dimensions are treated as size 1, and a
/// size-1 dimension stretches to match its counterpart.
///
/// # Panics
/// Panics when a pair of aligned dimensions differ and neither is 1.
pub fn broadcast_shape(lhs: &Shape, rhs: &Shape) -> Shape {
    let lhs_rank = lhs.num_dims();
    let rhs_rank = rhs.num_dims();
    let out_rank = lhs_rank.max(rhs_rank);
    // How many implicit leading 1-dims each operand contributes.
    let lhs_pad = out_rank - lhs_rank;
    let rhs_pad = out_rank - rhs_rank;
    let dims: Vec<_> = (0..out_rank)
        .map(|pos| {
            let a = if pos >= lhs_pad { lhs[pos - lhs_pad] } else { 1 };
            let b = if pos >= rhs_pad { rhs[pos - rhs_pad] } else { 1 };
            match (a, b) {
                (x, y) if x == y => x,
                (1, y) => y,
                (x, 1) => x,
                (x, y) => panic!(
                    "broadcast_shape: incompatible dimensions {} and {} at position {}",
                    x, y, pos
                ),
            }
        })
        .collect();
    Shape::from(dims)
}
pub fn broadcast_binary(lhs: FlexTensor, rhs: FlexTensor) -> (FlexTensor, FlexTensor) {
let lhs_shape = lhs.layout().shape().clone();
let rhs_shape = rhs.layout().shape().clone();
if lhs_shape == rhs_shape {
return (lhs, rhs);
}
let target = broadcast_shape(&lhs_shape, &rhs_shape);
let lhs_expanded = if lhs_shape == target {
lhs
} else {
expand(lhs, target.clone())
};
let rhs_expanded = if rhs_shape == target {
rhs
} else {
expand(rhs, target)
};
(lhs_expanded, rhs_expanded)
}
/// Returns a broadcast view of `tensor` with shape `target_shape`.
///
/// No data is copied: broadcast dimensions (newly introduced leading
/// dims and source dims of size 1) get stride 0, so every index along
/// them maps back to the same underlying element. All other strides —
/// including negative ones produced by flips — and the start offset
/// are carried over unchanged.
///
/// # Panics
/// Panics if the target rank is smaller than the source rank, or if a
/// source dimension other than 1 disagrees with the target dimension.
pub fn expand(tensor: FlexTensor, target_shape: Shape) -> FlexTensor {
    let src_dims = tensor.layout().shape().to_vec();
    let src_strides = tensor.layout().strides().to_vec();
    let start_offset = tensor.layout().start_offset();
    let dtype = tensor.dtype();
    let src_rank = src_dims.len();
    let target_rank = target_shape.num_dims();
    assert!(
        target_rank >= src_rank,
        "expand: target rank ({}) must be >= source rank ({}); \
        broadcasting cannot drop dimensions",
        target_rank,
        src_rank
    );
    // Dims present only in the target are pure broadcast dims.
    let rank_diff = target_rank - src_rank;
    let new_strides: Vec<_> = (0..target_rank)
        .map(|dim| {
            if dim < rank_diff {
                // Newly introduced leading dimension: repeat via stride 0.
                return 0;
            }
            let src_pos = dim - rank_diff;
            let size = src_dims[src_pos];
            let stride = src_strides[src_pos];
            if size == target_shape[dim] {
                stride
            } else if size == 1 {
                // Size-1 source dim stretches: stride 0 revisits the element.
                0
            } else {
                panic!(
                    "expand: cannot expand dimension {} from {} to {}",
                    dim, size, target_shape[dim]
                );
            }
        })
        .collect();
    let layout = Layout::new(target_shape, new_strides, start_offset);
    FlexTensor::from_arc(tensor.data_arc(), layout, dtype)
}
#[cfg(test)]
mod tests {
// Tests for the broadcast/expand view machinery. They assert on the
// resulting Layout (shape/strides/offset) and on materialized data to
// confirm that expand is a pure view transformation.
use super::*;
use burn_backend::TensorData;
// [3] -> [2, 3]: the new leading dim must get stride 0 (rows alias
// the same data), the trailing dim keeps its contiguous stride 1.
#[test]
fn test_expand_1d_to_2d() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0], [3]));
let expanded = expand(tensor, Shape::new([2, 3]));
assert_eq!(expanded.layout().shape().to_vec(), vec![2, 3]);
assert_eq!(expanded.layout().strides(), &[0, 1]);
}
// [3, 1] -> [3, 4]: an existing size-1 dim is stretched via stride 0.
#[test]
fn test_expand_broadcast_dim() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0], [3, 1]));
let expanded = expand(tensor, Shape::new([3, 4]));
assert_eq!(expanded.layout().shape().to_vec(), vec![3, 4]);
assert_eq!(expanded.layout().strides()[1], 0);
}
// Expanding to the identical shape must be a no-op on the layout.
#[test]
fn test_expand_same_shape() {
let tensor = FlexTensor::from_data(TensorData::new(
vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
[2, 3],
));
let original_strides = tensor.layout().strides().to_vec();
let expanded = expand(tensor, Shape::new([2, 3]));
assert_eq!(expanded.layout().shape().to_vec(), vec![2, 3]);
assert_eq!(expanded.layout().strides(), &original_strides);
}
// Non-contiguous (transposed) source: the transposed strides [1, 2]
// must survive behind the new 0-stride leading dim, and materializing
// must repeat the transposed matrix once per leading slice.
#[test]
fn test_expand_transposed() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]));
let transposed = tensor.transpose(0, 1);
assert!(!transposed.is_contiguous());
assert_eq!(transposed.layout().strides(), &[1, 2]);
let expanded = expand(transposed, Shape::new([3, 2, 2]));
assert_eq!(expanded.layout().shape().to_vec(), vec![3, 2, 2]);
assert_eq!(expanded.layout().strides(), &[0, 1, 2]);
let data: Vec<f32> = expanded.into_data().to_vec().unwrap();
assert_eq!(
data,
vec![1.0, 3.0, 2.0, 4.0, 1.0, 3.0, 2.0, 4.0, 1.0, 3.0, 2.0, 4.0]
);
}
// Flipped 1-D source: flip yields a negative stride, which expand
// must preserve in the trailing dim while the new leading dim is 0.
#[test]
fn test_expand_flipped_1d() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0], [3]));
let flipped = crate::ops::flip::flip(tensor, &[0]);
assert!(flipped.layout().strides()[0] < 0);
let expanded = expand(flipped, Shape::new([2, 3]));
assert_eq!(expanded.layout().shape().to_vec(), vec![2, 3]);
assert_eq!(expanded.layout().strides()[0], 0);
assert!(expanded.layout().strides()[1] < 0);
let data: Vec<f32> = expanded.into_data().to_vec().unwrap();
assert_eq!(data, vec![3.0, 2.0, 1.0, 3.0, 2.0, 1.0]);
}
// Same as above in 2-D: only dim 0 of the source is flipped, so the
// expanded view keeps (0, negative, 1) strides and repeats the
// row-reversed matrix per leading slice.
#[test]
fn test_expand_flipped_2d_preserves_negative_stride() {
let tensor = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]));
let flipped = crate::ops::flip::flip(tensor, &[0]);
assert!(flipped.layout().strides()[0] < 0);
assert_eq!(flipped.layout().strides()[1], 1);
let expanded = expand(flipped, Shape::new([3, 2, 2]));
assert_eq!(expanded.layout().shape().to_vec(), vec![3, 2, 2]);
assert_eq!(expanded.layout().strides()[0], 0);
assert!(expanded.layout().strides()[1] < 0);
assert_eq!(expanded.layout().strides()[2], 1);
let data: Vec<f32> = expanded.into_data().to_vec().unwrap();
assert_eq!(
data,
vec![3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0]
);
}
// Narrowed source: the non-zero start_offset from narrow() must be
// carried through unchanged so the view still addresses elements 1..4.
#[test]
fn test_expand_narrowed_preserves_offset() {
let tensor = FlexTensor::from_data(TensorData::new(vec![0.0f32, 1.0, 2.0, 3.0, 4.0], [5]));
let narrowed = tensor.narrow(0, 1, 3);
assert_eq!(narrowed.layout().start_offset(), 1);
let expanded = expand(narrowed, Shape::new([2, 3]));
assert_eq!(expanded.layout().shape().to_vec(), vec![2, 3]);
assert_eq!(expanded.layout().start_offset(), 1);
let data: Vec<f32> = expanded.into_data().to_vec().unwrap();
assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
// Rank reduction is rejected: 3-D -> 2-D must hit the rank assert.
#[test]
#[should_panic(expected = "broadcasting cannot drop dimensions")]
fn test_expand_to_fewer_dims_panics() {
let tensor = FlexTensor::from_data(TensorData::new(alloc::vec![0.0f32; 24], [2, 3, 4]));
let _ = expand(tensor, Shape::new([2, 3]));
}
// Rank reduction edge case: 1-D -> rank-0 (scalar) must also panic.
#[test]
#[should_panic(expected = "broadcasting cannot drop dimensions")]
fn test_expand_1d_to_scalar_panics() {
let tensor = FlexTensor::from_data(TensorData::new(alloc::vec![1.0f32, 2.0, 3.0], [3]));
let _ = expand(tensor, Shape::from(alloc::vec::Vec::<usize>::new()));
}
// End-to-end: broadcast_binary on a flipped [4] lhs and a [2, 1] rhs
// must yield two [2, 4] views, with the lhs keeping its negative
// trailing stride (reversed data repeated per row).
#[test]
fn test_broadcast_binary_with_flipped() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [4]));
let lhs = crate::ops::flip::flip(lhs, &[0]);
assert!(lhs.layout().strides()[0] < 0);
let rhs = FlexTensor::from_data(TensorData::new(vec![10.0f32, 20.0], [2, 1]));
let (lhs_bc, rhs_bc) = broadcast_binary(lhs, rhs);
assert_eq!(lhs_bc.layout().shape().to_vec(), vec![2, 4]);
assert_eq!(rhs_bc.layout().shape().to_vec(), vec![2, 4]);
assert_eq!(lhs_bc.layout().strides()[0], 0);
assert!(lhs_bc.layout().strides()[1] < 0);
let lhs_data: Vec<f32> = lhs_bc.into_data().to_vec().unwrap();
assert_eq!(lhs_data, vec![4.0, 3.0, 2.0, 1.0, 4.0, 3.0, 2.0, 1.0]);
}
}