/// Computes the shape that results from broadcasting shapes `a` and `b` together.
///
/// The two shapes are aligned at their trailing (rightmost) axes, and the
/// shorter one behaves as if it were left-padded with `1`s. For every aligned
/// pair of sizes:
/// * equal sizes are kept as-is;
/// * a size of `1` on one side takes on the size from the other side.
///
/// For example, `[3]` with `[2, 3]` gives `[2, 3]`, and `[4, 1]` with `[3]`
/// gives `[4, 3]`.
///
/// Returns `None` as soon as an aligned pair differs and neither size is `1`.
/// The result has `max(a.len(), b.len())` axes.
pub fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let ndim = a.len().max(b.len());

    // Walk both shapes from the last axis towards the first; once a shape runs
    // out of axes it contributes implicit 1s.
    let mut lhs = a.iter().rev().copied();
    let mut rhs = b.iter().rev().copied();

    // Collecting `Option`s stops at the first incompatible pair and yields `None`.
    let mut dims = (0..ndim)
        .map(|_| {
            let x = lhs.next().unwrap_or(1);
            let y = rhs.next().unwrap_or(1);
            match (x, y) {
                (x, y) if x == y => Some(x),
                (1, y) => Some(y),
                (x, 1) => Some(x),
                _ => None,
            }
        })
        .collect::<Option<Vec<usize>>>()?;

    // `dims` was built last-axis-first; restore the conventional order.
    dims.reverse();
    Some(dims)
}
/// Identifies a two-operand operation.
///
/// This type only names the operation. How each variant is evaluated,
/// including the operand order for non-commutative operations such as `Sub`,
/// `Div`, `Pow` and `Atan2`, is defined by the code that consumes it.
///
/// Derives `Hash` alongside `Eq` so an op can be used as a `HashMap`/`HashSet` key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
    Pow,
    Max,
    Min,
    // NOTE(review): presumably `atan2(lhs, rhs)` with `lhs` as the
    // y-coordinate — confirm against the evaluator.
    Atan2,
}
/// Identifies a single-operand operation.
///
/// This type only names the operation. Its numeric definition (domain
/// handling, rounding mode, and so on) lives in whatever code evaluates it.
///
/// Derives `Hash` alongside `Eq` so an op can be used as a `HashMap`/`HashSet` key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    // Sign and magnitude.
    Neg,
    Abs,
    // NOTE(review): the value produced for zero (and NaN, if floats are
    // involved) is defined by the evaluator — confirm.
    Sign,

    // Powers and roots.
    Sqrt,
    // Presumably `1 / sqrt(x)` — confirm against the evaluator.
    Rsqrt,
    Square,
    Cbrt,
    // Presumably `1 / x` — confirm against the evaluator.
    Recip,

    // Exponentials and logarithms.
    Exp,
    Exp2,
    Expm1,
    Log,
    Log2,
    Log10,
    Log1p,

    // Trigonometric functions and their inverses.
    Sin,
    Cos,
    Tan,
    Asin,
    Acos,
    Atan,

    // Hyperbolic functions and their inverses.
    Sinh,
    Cosh,
    Tanh,
    Asinh,
    Acosh,
    Atanh,

    // Rounding.
    Floor,
    Ceil,
    // NOTE(review): the tie-breaking rule (half away from zero vs. half to
    // even) is decided by the evaluator — confirm.
    Round,
    Trunc,
}
/// Identifies a comparison between two operands.
///
/// Variants follow the conventional `==`, `!=`, `<`, `<=`, `>`, `>=` naming.
/// Operand order (which side is the left-hand side) is fixed by the code that
/// evaluates these ops.
///
/// Derives `Hash` alongside `Eq` so an op can be used as a `HashMap`/`HashSet` key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum CompareOp {
    /// Equal (`==`).
    Eq,
    /// Not equal (`!=`).
    Ne,
    /// Less than (`<`).
    Lt,
    /// Less than or equal (`<=`).
    Le,
    /// Greater than (`>`).
    Gt,
    /// Greater than or equal (`>=`).
    Ge,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that broadcasting does not depend on argument order.
    fn assert_symmetric(a: &[usize], b: &[usize]) {
        assert_eq!(broadcast_shape(a, b), broadcast_shape(b, a));
    }

    #[test]
    fn test_broadcast_shape() {
        assert_eq!(broadcast_shape(&[2, 3], &[2, 3]), Some(vec![2, 3]));
        assert_eq!(broadcast_shape(&[2, 3], &[1, 3]), Some(vec![2, 3]));
        assert_eq!(broadcast_shape(&[2, 1], &[2, 3]), Some(vec![2, 3]));
        assert_eq!(broadcast_shape(&[3], &[2, 3]), Some(vec![2, 3]));
        assert_eq!(broadcast_shape(&[2, 3], &[3]), Some(vec![2, 3]));
        assert_eq!(broadcast_shape(&[2, 3], &[2, 4]), None);
        assert_eq!(broadcast_shape(&[3], &[4]), None);
    }

    #[test]
    fn broadcast_shape_empty_shapes() {
        // A zero-length shape contributes only implicit 1s.
        let scalar: [usize; 0] = [];
        assert_eq!(broadcast_shape(&scalar, &scalar), Some(vec![]));
        assert_eq!(broadcast_shape(&scalar, &[2, 3]), Some(vec![2, 3]));
        assert_eq!(broadcast_shape(&[4], &scalar), Some(vec![4]));
    }

    #[test]
    fn broadcast_shape_stretches_both_operands() {
        assert_eq!(broadcast_shape(&[4, 1], &[3]), Some(vec![4, 3]));
        assert_eq!(broadcast_shape(&[1, 3], &[2, 1]), Some(vec![2, 3]));
        assert_eq!(broadcast_shape(&[5, 1, 4], &[3, 1]), Some(vec![5, 3, 4]));
    }

    #[test]
    fn broadcast_shape_zero_sized_dims() {
        // A size-0 axis broadcasts against 1 but not against any other size.
        assert_eq!(broadcast_shape(&[0, 3], &[1, 3]), Some(vec![0, 3]));
        assert_eq!(broadcast_shape(&[1], &[0]), Some(vec![0]));
        assert_eq!(broadcast_shape(&[0], &[3]), None);
    }

    #[test]
    fn broadcast_shape_is_symmetric() {
        assert_symmetric(&[2, 3], &[3]);
        assert_symmetric(&[4, 1], &[3]);
        assert_symmetric(&[5, 1, 4], &[3, 1]);
        assert_symmetric(&[0], &[1]);
        assert_symmetric(&[2, 3], &[2, 4]);
    }
}