/// Computes the shape that results from broadcasting `shape1` against `shape2`.
///
/// The two shapes are aligned at their trailing (rightmost) dimensions; the
/// shorter shape is treated as if it were padded on the left with dimensions
/// of size 1. An aligned pair of dimensions is compatible when both sizes are
/// equal or when either of them is 1, in which case the size-1 dimension
/// stretches to the other size.
///
/// A size-1 dimension also stretches to a size-0 dimension, so broadcasting
/// `[0]` with `[1]` yields `[0]` (not `[1]`).
///
/// Returns `Some(shape)` with `max(shape1.len(), shape2.len())` dimensions, or
/// `None` if any aligned pair of dimensions is incompatible. Two empty shapes
/// broadcast to an empty shape.
pub fn broadcast_shapes(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let max_len = shape1.len().max(shape2.len());
    let mut result = Vec::with_capacity(max_len);

    // Walk both shapes from the trailing dimension; once a shape runs out,
    // its missing leading dimensions behave like size 1.
    let mut dims1 = shape1.iter().rev().copied();
    let mut dims2 = shape2.iter().rev().copied();

    for _ in 0..max_len {
        let dim1 = dims1.next().unwrap_or(1);
        let dim2 = dims2.next().unwrap_or(1);

        // Pick the non-1 side explicitly rather than taking the max: with a
        // (1, 0) pair the correct broadcast size is 0, which `max` gets wrong.
        let dim = match (dim1, dim2) {
            (a, b) if a == b => a,
            (1, b) => b,
            (a, 1) => a,
            _ => return None,
        };
        result.push(dim);
    }

    // Dimensions were collected trailing-first; restore leading-first order.
    result.reverse();
    Some(result)
}
#[cfg(test)]
mod tests {
    use super::broadcast_shapes;

    /// Exercises the three basic outcomes of `broadcast_shapes`: mutual
    /// stretching of size-1 axes, rank extension of the shorter shape, and
    /// rejection of incompatible sizes.
    #[test]
    fn test_broadcast_shapes() {
        // Each operand has a size-1 axis that stretches to the other's size.
        assert_eq!(broadcast_shapes(&[3, 1], &[1, 4]), Some(vec![3, 4]));

        // The lower-rank shape is aligned to the trailing axis.
        assert_eq!(broadcast_shapes(&[2, 3], &[3]), Some(vec![2, 3]));

        // Sizes differ and neither is 1, so no broadcast shape exists.
        assert!(broadcast_shapes(&[2], &[3]).is_none());
    }
}