use super::helpers::*;
use morok_dtype::DType;
use morok_ir::UOp;
use morok_ir::types::ConstValue;
/// Adding a scalar constant to an UNROLL applies it to every lane and leaves
/// the UNROLL's single axis untouched.
#[test]
fn test_expand_add_broadcast() {
    let lanes = create_unroll_iota(1, 4);
    let three = UOp::const_(DType::Int64, ConstValue::Int(3));
    let expanded = phase2_only(&lanes.try_add(&three).unwrap());

    assert_result_values(&expanded, &[3, 4, 5, 6]);

    let (_, unroll_axes) = unwrap_unroll(&expanded);
    assert_eq!(unroll_axes, vec![(1, 4)], "Should preserve axis");
}
/// Two UNROLLs over the same axis combine lane-by-lane and keep that one axis.
#[test]
fn test_expand_same_axis() {
    let iota = create_unroll_iota(1, 4);
    let scaled = create_unroll_scaled(1, 4, 4);
    let expanded = phase2_only(&iota.try_add(&scaled).unwrap());

    assert_result_values(&expanded, &[0, 5, 10, 15]);

    let (_, unroll_axes) = unwrap_unroll(&expanded);
    assert_eq!(unroll_axes, vec![(1, 4)], "Should preserve axis");
}
#[test]
fn test_expand_different_axis() {
let e1 = create_unroll_scaled(1, 4, 4);
let e2 = create_unroll_iota(2, 4);
let add = e1.try_add(&e2).unwrap();
let result = phase2_only(&add);
let expected: Vec<i64> = (0..16).collect();
assert_result_values(&result, &expected);
let (_, axes) = unwrap_unroll(&result);
assert_eq!(axes, vec![(1, 4), (2, 4)], "Should have both axes");
}
#[test]
fn test_expand_different_axis_flip() {
let e2 = create_unroll_iota(2, 4);
let e1 = create_unroll_scaled(1, 4, 4);
let add = e2.try_add(&e1).unwrap();
let result = phase2_only(&add);
let expected: Vec<i64> = (0..16).collect();
assert_result_values(&result, &expected);
let (_, axes) = unwrap_unroll(&result);
assert_eq!(axes, vec![(1, 4), (2, 4)], "Should have both axes");
}
/// UNROLLs over three distinct axes merge into a single UNROLL whose inner
/// vector holds one lane per axis combination (4 * 4 * 4 = 64 lanes).
///
/// Lanes are checked as well as the axis structure. Axis 1 is the outermost
/// axis and axis 3 the innermost, which is the same ordering the two-axis
/// tests rely on when they expect `0..16` in order.
#[test]
fn test_expand_three_axes() {
    let e1 = create_unroll_scaled(1, 4, 4);
    let e2 = create_unroll_iota(2, 4);
    let e3 = create_unroll_scaled(3, 4, 16);
    let sum = e1.try_add(&e2).unwrap().try_add(&e3).unwrap();
    let result = phase2_only(&sum);

    // Lane (i1, i2, i3) sits at flat index 16*i1 + 4*i2 + i3 and holds
    // 4*i1 + i2 + 16*i3: the axis-1 operand is scaled by 4 and the axis-3
    // operand by 16.
    let expected: Vec<i64> = (0..4)
        .flat_map(|i1| {
            (0..4).flat_map(move |i2| (0..4).map(move |i3| 4 * i1 + i2 + 16 * i3))
        })
        .collect();
    assert_result_values(&result, &expected);

    let (src, axes) = unwrap_unroll(&result);
    assert_eq!(axes, vec![(1, 4), (2, 4), (3, 4)], "Should have three axes");
    assert_eq!(src.dtype().vcount(), 64, "Inner should be vec64");
}
/// Multiplying an UNROLL by a scalar constant applies it to every lane.
///
/// Like the add-broadcast test, this also checks that the UNROLL keeps its
/// single axis.
#[test]
fn test_expand_mul_broadcast() {
    let unroll = create_unroll_iota(1, 4);
    let scalar = UOp::const_(DType::Int64, ConstValue::Int(2));
    let mul = unroll.try_mul(&scalar).unwrap();
    let result = phase2_only(&mul);
    assert_result_values(&result, &[0, 2, 4, 6]);

    let (_, axes) = unwrap_unroll(&result);
    assert_eq!(axes, vec![(1, 4)], "Should preserve axis");
}
/// Multiplying two UNROLLs on the same axis works lane-by-lane (squares here)
/// and keeps that single axis.
#[test]
fn test_expand_mul_same_axis() {
    let e1 = create_unroll_values(1, vec![1, 2, 3, 4]);
    let e2 = create_unroll_values(1, vec![1, 2, 3, 4]);
    let mul = e1.try_mul(&e2).unwrap();
    let result = phase2_only(&mul);
    assert_result_values(&result, &[1, 4, 9, 16]);

    let (_, axes) = unwrap_unroll(&result);
    assert_eq!(axes, vec![(1, 4)], "Should preserve axis");
}
/// Subtracting a scalar constant from an UNROLL applies it to every lane and
/// keeps the UNROLL's single axis.
#[test]
fn test_expand_sub_broadcast() {
    let unroll = create_unroll_values(1, vec![10, 20, 30, 40]);
    let scalar = UOp::const_(DType::Int64, ConstValue::Int(5));
    let sub = unroll.try_sub(&scalar).unwrap();
    let result = phase2_only(&sub);
    assert_result_values(&result, &[5, 15, 25, 35]);

    let (_, axes) = unwrap_unroll(&result);
    assert_eq!(axes, vec![(1, 4)], "Should preserve axis");
}
/// A nested expression, `(iota + 1) * [2, 2, 2, 2]`, expands correctly when
/// every UNROLL operand shares axis 1, and the result keeps that axis.
#[test]
fn test_expand_compound_expression() {
    let e1 = create_unroll_iota(1, 4);
    let e2 = create_unroll_values(1, vec![2, 2, 2, 2]);
    let scalar = UOp::const_(DType::Int64, ConstValue::Int(1));
    let sum = e1.try_add(&scalar).unwrap();
    let result = phase2_only(&sum.try_mul(&e2).unwrap());
    assert_result_values(&result, &[2, 4, 6, 8]);

    let (_, axes) = unwrap_unroll(&result);
    assert_eq!(axes, vec![(1, 4)], "Should preserve axis");
}