burn_tensor/tensor/ops/
binary.rs

1
2
3
4
5
6
7
8
9
10
11
12
use alloc::vec::Vec;

/// Computes the output shape for binary operations with broadcasting support.
///
/// Shapes are aligned on their trailing (right-most) dimensions, NumPy-style.
/// When the ranks differ, the missing leading dimensions of the shorter shape are
/// treated as `1`. For each aligned pair of dimensions:
///
/// - if one of them is `1`, the other one is used (so `1` against `0` yields `0`);
/// - otherwise both are expected to be equal, and that size is used.
///
/// The returned shape has rank `max(lhs.len(), rhs.len())`.
///
/// # Panics
///
/// In debug builds, panics if an aligned pair of dimensions differs and neither is `1`.
/// In release builds such a pair falls back to the larger of the two sizes.
pub fn binary_ops_shape(lhs: &[usize], rhs: &[usize]) -> Vec<usize> {
    /// Broadcasts a single pair of aligned dimension sizes.
    fn broadcast_dim(l: usize, r: usize) -> usize {
        match (l, r) {
            (1, r) => r,
            (l, 1) => l,
            (l, r) => {
                // Only checked in debug builds to keep release behavior unchanged
                // for callers that validate shapes elsewhere.
                debug_assert_eq!(l, r, "incompatible dimensions for broadcasting: {l} vs {r}");
                usize::max(l, r)
            }
        }
    }

    let rank = usize::max(lhs.len(), rhs.len());

    // Size of dimension `i` of `shape` once right-aligned to `rank`;
    // absent leading dimensions behave as size 1.
    let dim_at = |shape: &[usize], i: usize| {
        let offset = rank - shape.len();
        if i < offset { 1 } else { shape[i - offset] }
    };

    (0..rank)
        .map(|i| broadcast_dim(dim_at(lhs, i), dim_at(rhs, i)))
        .collect()
}