/// Computes the NumPy-style broadcast of two shapes.
///
/// Shapes are right-aligned; a shape with fewer axes is treated as if padded
/// on the left with extents of 1. Along each axis the two extents must either
/// be equal or one of them must be 1, in which case the other extent wins.
///
/// Returns `None` as soon as an axis has two different extents, neither of
/// which is 1.
pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let rank = a.len().max(b.len());
    let pad_a = rank - a.len();
    let pad_b = rank - b.len();
    (0..rank)
        .map(|axis| {
            // Axes that fall inside the left padding behave as extent 1.
            let lhs = axis.checked_sub(pad_a).map_or(1, |k| a[k]);
            let rhs = axis.checked_sub(pad_b).map_or(1, |k| b[k]);
            match (lhs, rhs) {
                (x, y) if x == y => Some(x),
                (1, y) => Some(y),
                (x, 1) => Some(x),
                _ => None,
            }
        })
        .collect()
}
/// Computes per-axis element strides for reading a contiguous, row-major
/// operand of `operand_shape` as if it had been broadcast to `output_shape`.
///
/// The result has one entry per output axis. Leading output axes the operand
/// lacks, and axes where the operand has extent 1, get stride 0, so every
/// output coordinate along them maps to the same operand element.
///
/// # Panics
/// Panics if `operand_shape` has more axes than `output_shape`.
pub fn broadcast_strides(operand_shape: &[usize], output_shape: &[usize]) -> Vec<usize> {
    let lead = output_shape.len() - operand_shape.len();
    let mut out = vec![0usize; output_shape.len()];
    let mut running = 1usize;
    // Walk the right-aligned operand axes from innermost to outermost.
    for (slot, &extent) in out[lead..].iter_mut().zip(operand_shape).rev() {
        if extent != 1 {
            *slot = running;
            running *= extent;
        }
    }
    out
}
/// Records how an operand was broadcast, so that a tensor of the broadcast
/// (output) shape can later be summed back down to the operand's shape.
#[derive(Debug, Clone)]
pub struct BroadcastInfo {
    /// Shape of the operand before broadcasting.
    pub original_shape: Vec<usize>,
    /// Axes of the *output* shape (indices into `output_shape`) along which the
    /// operand was expanded, in ascending order: every leading axis the operand
    /// lacked, plus every axis where the operand had extent 1 and the output
    /// extent was not 1.
    pub reduced_dims: Vec<usize>,
}

impl BroadcastInfo {
    /// Describes the broadcast of `operand_shape` to `output_shape`.
    ///
    /// Returns `None` when no axis was expanded, i.e. when no reduction is
    /// needed to map an output-shaped tensor back onto the operand.
    ///
    /// An operand axis of extent 1 counts as expanded whenever the output
    /// extent differs from 1. This includes an output extent of 0, which
    /// broadcasting allows: summing over that empty axis yields the operand's
    /// extent of 1 (filled with zeros) rather than leaving it at 0.
    ///
    /// # Panics
    /// Panics if `operand_shape` has more axes than `output_shape`.
    pub fn new(operand_shape: &[usize], output_shape: &[usize]) -> Option<Self> {
        // Explicit check: a wrapping subtraction in release builds would make
        // the leading-axis range below astronomically large.
        let offset = output_shape
            .len()
            .checked_sub(operand_shape.len())
            .expect("operand_shape must not have more axes than output_shape");

        let reduced_dims: Vec<usize> = (0..offset)
            .chain(
                operand_shape
                    .iter()
                    .zip(&output_shape[offset..])
                    .enumerate()
                    .filter(|&(_, (&src, &dst))| src == 1 && dst != 1)
                    .map(|(d, _)| offset + d),
            )
            .collect();

        if reduced_dims.is_empty() {
            None
        } else {
            Some(Self {
                original_shape: operand_shape.to_vec(),
                reduced_dims,
            })
        }
    }
}
/// Returns the row-major (C-order) element strides of a contiguous tensor.
///
/// Entry `d` is the product of all extents after axis `d`, so the last entry
/// is always 1. The extent of axis 0 is never multiplied in. An empty shape
/// yields an empty vector.
pub fn suffix_products(shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return Vec::new();
    }
    // Accumulate products of the trailing extents from the innermost axis outward.
    let mut strides: Vec<usize> = shape
        .iter()
        .skip(1)
        .rev()
        .scan(1usize, |acc, &extent| {
            *acc *= extent;
            Some(*acc)
        })
        .collect();
    strides.reverse();
    strides.push(1);
    strides
}
/// Applies `op` element-wise to two broadcast operands on the CPU.
///
/// `out` is filled as a contiguous row-major tensor of `output_shape`. For each
/// output position, the multi-index is recovered from the flat index. Each
/// operand's source element is then found by dotting that multi-index with
/// the operand's per-axis strides (`a_strides` / `b_strides`, one entry per
/// output axis).
///
/// # Panics
/// Panics on out-of-bounds indexing if `out` holds fewer elements than
/// `output_shape` describes, if a stride slice has fewer entries than
/// `output_shape`, or if a computed offset falls outside `a` or `b`.
pub fn cpu_broadcast_binary(
    a: &[f32],
    b: &[f32],
    out: &mut [f32],
    output_shape: &[usize],
    a_strides: &[usize],
    b_strides: &[usize],
    op: fn(f32, f32) -> f32,
) {
    let total: usize = output_shape.iter().product();
    let rank = output_shape.len();
    let place_values = suffix_products(output_shape);

    for flat in 0..total {
        let mut rest = flat;
        let (mut off_a, mut off_b) = (0usize, 0usize);
        for axis in 0..rank {
            let step = place_values[axis];
            let coord = rest / step;
            // Same as `rest %= step`, reusing the quotient just computed.
            rest -= coord * step;
            off_a += coord * a_strides[axis];
            off_b += coord * b_strides[axis];
        }
        out[flat] = op(a[off_a], b[off_b]);
    }
}
/// Sums `src` over the axes listed in `reduced_dims`, writing the result to `dst`.
///
/// `src` is read as a contiguous row-major tensor of `input_shape`. The result
/// keeps the input's rank with every reduced axis collapsed to extent 1, and is
/// written to `dst` as a contiguous row-major tensor of that shape. Only the
/// first `out_numel` elements of `dst` are zeroed and accumulated into. Any
/// elements beyond the output's element count are left untouched. Reducing
/// over an axis of extent 0 yields zeros.
///
/// # Panics
/// Panics if an entry of `reduced_dims` is not a valid axis of `input_shape`,
/// if `src` holds fewer elements than `input_shape` describes, or (when the
/// input is non-empty) if `dst` holds fewer elements than the output shape
/// describes.
pub fn cpu_reduce_sum(
    src: &[f32],
    dst: &mut [f32],
    input_shape: &[usize],
    reduced_dims: &[usize],
) {
    let in_numel: usize = input_shape.iter().product();
    let in_suffix = suffix_products(input_shape);

    // Output shape: same rank as the input, reduced axes collapsed to 1.
    let mut out_shape = input_shape.to_vec();
    for &d in reduced_dims {
        out_shape[d] = 1;
    }
    let out_numel: usize = out_shape.iter().product();

    // Per-axis stride into `dst`, with 0 on reduced axes (the output coordinate
    // there is always 0). It is computed once so the hot loop below needs no
    // per-element `reduced_dims` membership tests.
    let mut out_strides = suffix_products(&out_shape);
    for &d in reduced_dims {
        out_strides[d] = 0;
    }

    for v in dst.iter_mut().take(out_numel) {
        *v = 0.0;
    }

    for (i, &value) in src[..in_numel].iter().enumerate() {
        let mut remainder = i;
        let mut out_idx = 0usize;
        for (&place, &stride) in in_suffix.iter().zip(&out_strides) {
            let coord = remainder / place;
            remainder %= place;
            out_idx += coord * stride;
        }
        dst[out_idx] += value;
    }
}