use super::helpers::*;
use morok_ir::{Op, UOp};
/// Fully contracting a single 4-wide unroll on axis 1 is expected to yield the lane values
/// `[0, 1, 2, 3]`.
#[test]
fn test_contract_simple() {
    let op = create_contract(create_unroll_iota(1, 4), vec![(1, 4)]);
    let lowered = phase2_only(&op);
    assert_result_values(&lowered, &[0, 1, 2, 3]);
}
/// Contracting axis 1 of a 4x4 unroll should leave only axis 2 unrolled. The expected GEP
/// picks source lanes in stride-4 groups: `0, 4, 8, 12`, then `1, 5, 9, 13`, and so on.
#[test]
fn test_contract_partial_axis_1() {
    let op = create_contract(create_unroll_multi_axis(vec![(1, 4), (2, 4)]), vec![(1, 4)]);
    let lowered = phase2_only(&op);

    let (gep, axes_left) = unwrap_unroll(&lowered);
    assert_eq!(axes_left, vec![(2, 4)], "Should have axis 2 remaining");

    let picked = unwrap_gep(&gep).1;
    assert_eq!(
        picked,
        vec![0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
        "GEP indices for axis 1 contraction"
    );

    // The source lanes carry their own index, so values mirror the GEP indices.
    let lane_values = extract_result_values(&gep);
    assert_eq!(
        lane_values,
        vec![0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15],
        "Contracted values from axis 1"
    );
}
/// Contracting axis 2 of a 4x4 unroll should leave only axis 1 unrolled. The expected GEP
/// reads source lanes in their original order (`0..16`).
#[test]
fn test_contract_partial_axis_2() {
    let op = create_contract(create_unroll_multi_axis(vec![(1, 4), (2, 4)]), vec![(2, 4)]);
    let lowered = phase2_only(&op);

    let (gep, axes_left) = unwrap_unroll(&lowered);
    assert_eq!(axes_left, vec![(1, 4)], "Should have axis 1 remaining");

    let picked = unwrap_gep(&gep).1;
    assert_eq!(
        picked,
        vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        "GEP indices for axis 2 contraction"
    );

    let lane_values = extract_result_values(&gep);
    assert_eq!(
        lane_values,
        vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
        "Contracted values from axis 2"
    );
}
/// Contracting only axis 2 out of four 2-wide axes should leave axes 1, 3 and 4 unrolled,
/// in that order.
#[test]
fn test_contract_four_axes() {
    let axes = vec![(1, 2), (2, 2), (3, 2), (4, 2)];
    let op = create_contract(create_unroll_multi_axis(axes), vec![(2, 2)]);
    let lowered = phase2_only(&op);

    let axes_left = unwrap_unroll(&lowered).1;
    assert_eq!(axes_left, vec![(1, 2), (3, 2), (4, 2)], "Should have axes 1, 3, 4 remaining");
}
/// Contracting two adjacent axes (1 and 2) out of three 2-wide axes should leave axis 3
/// unrolled. It should also gather the contracted lanes in contract-argument order.
///
/// The source lane for `(a1, a2, a3)` is `a1*4 + a2*2 + a3`. This is the same last-axis-fastest
/// layout that `test_contract_middle_axis` pins down. For each remaining `a3`, the contracted
/// `(a1, a2)` pairs are visited with `a2` (the last contract argument) varying fastest.
#[test]
fn test_contract_multi_axis_order_1() {
    let unroll = create_unroll_multi_axis(vec![(1, 2), (2, 2), (3, 2)]);
    let contract = create_contract(unroll, vec![(1, 2), (2, 2)]);
    let result = phase2_only(&contract);

    let (gep, remaining_axes) = unwrap_unroll(&result);
    assert_eq!(remaining_axes, vec![(3, 2)], "Should have axis 3 remaining");

    // Previously only the remaining axes were checked, so a wrong lane order went unnoticed.
    let (_, indices) = unwrap_gep(&gep);
    assert_eq!(indices, vec![0, 2, 4, 6, 1, 3, 5, 7], "GEP indices for axes 1, 2 contraction");

    let values = extract_result_values(&gep);
    assert_eq!(values, vec![0, 2, 4, 6, 1, 3, 5, 7], "Contracted values from axes 1, 2");
}
/// Contracting the trailing axes (2 and 3) out of three 2-wide axes should leave axis 1
/// unrolled. It should also gather the contracted lanes in contract-argument order.
///
/// With the last-axis-fastest lane layout `a1*4 + a2*2 + a3`, contracting the two fastest axes
/// means each remaining `a1` reads a contiguous run of four lanes. The expected GEP is
/// therefore the identity `0..8`. `test_contract_partial_axis_2` shows an identity GEP is kept.
#[test]
fn test_contract_multi_axis_order_2() {
    let unroll = create_unroll_multi_axis(vec![(1, 2), (2, 2), (3, 2)]);
    let contract = create_contract(unroll, vec![(2, 2), (3, 2)]);
    let result = phase2_only(&contract);

    let (gep, remaining_axes) = unwrap_unroll(&result);
    assert_eq!(remaining_axes, vec![(1, 2)], "Should have axis 1 remaining");

    // Previously only the remaining axes were checked, so a wrong lane order went unnoticed.
    let (_, indices) = unwrap_gep(&gep);
    assert_eq!(indices, vec![0, 1, 2, 3, 4, 5, 6, 7], "GEP indices for axes 2, 3 contraction");

    let values = extract_result_values(&gep);
    assert_eq!(values, vec![0, 1, 2, 3, 4, 5, 6, 7], "Contracted values from axes 2, 3");
}
/// Contracting the middle axis of a 2x2x2 unroll should keep axes 1 and 3 unrolled. The
/// expected GEP pairs each lane with the lane two positions later: `0, 2`, `1, 3`, `4, 6`,
/// `5, 7`.
#[test]
fn test_contract_middle_axis() {
    let src = create_unroll_multi_axis(vec![(1, 2), (2, 2), (3, 2)]);
    let op = create_contract(src, vec![(2, 2)]);
    let lowered = phase2_only(&op);

    let (gep, axes_left) = unwrap_unroll(&lowered);
    assert_eq!(axes_left, vec![(1, 2), (3, 2)], "Should have axes 1, 3 remaining");

    let picked = unwrap_gep(&gep).1;
    assert_eq!(picked, vec![0, 2, 1, 3, 4, 6, 5, 7], "GEP indices for middle axis contraction");

    assert_eq!(
        extract_result_values(&gep),
        vec![0, 2, 1, 3, 4, 6, 5, 7],
        "Contracted values from middle axis"
    );
}
/// A CONTRACT whose source is a plain scalar constant (not an UNROLL) is expected to lower
/// to a 4-lane vectorize that repeats the scalar in every lane.
#[test]
fn test_contract_non_unroll_source() {
    use morok_dtype::DType;
    use morok_ir::types::ConstValue;

    let four = UOp::const_(DType::Int64, ConstValue::Int(4));
    let op = create_contract(four, vec![(0, 4)]);
    let lowered = phase2_only(&op);

    assert_is_vectorize(&lowered, 4);
    assert_result_values(&lowered, &[4, 4, 4, 4]);
}
/// Contracting over an axis (0) that the source unroll does not carry should replicate the
/// source lanes. The result is expected to be a GEP that reads lanes `0..4` twice.
#[test]
fn test_contract_partial_expansion() {
    let op = create_contract(create_unroll_iota(1, 4), vec![(0, 2), (1, 4)]);
    let lowered = phase2_only(&op);

    let Op::Gep { indices, .. } = lowered.op() else {
        panic!("Expected GEP, got {:?}", lowered.op());
    };
    assert_eq!(indices, &[0, 1, 2, 3, 0, 1, 2, 3], "Should duplicate for missing axis");

    assert_result_values(&lowered, &[0, 1, 2, 3, 0, 1, 2, 3]);
}
/// Partially contracting axis 1 (width 4) out of a 4x2 unroll should give a result whose
/// vcount (as checked by `assert_vcount`) is 4, with axis 2 still unrolled.
#[test]
fn test_contract_partial_dtype_validation() {
    let op = create_contract(create_unroll_multi_axis(vec![(1, 4), (2, 2)]), vec![(1, 4)]);
    let lowered = phase2_only(&op);

    assert_vcount(&lowered, 4);
    let axes_left = unwrap_unroll(&lowered).1;
    assert_eq!(axes_left, vec![(2, 2)]);
}
/// Same check as the 4x2 case, but with equal axis widths (4x4). The vcount should follow
/// the contracted axis (4), and axis 2 should remain unrolled.
#[test]
fn test_contract_partial_dtype_same_sizes() {
    let op = create_contract(create_unroll_multi_axis(vec![(1, 4), (2, 4)]), vec![(1, 4)]);
    let lowered = phase2_only(&op);

    assert_vcount(&lowered, 4);
    let axes_left = unwrap_unroll(&lowered).1;
    assert_eq!(axes_left, vec![(2, 4)]);
}
/// With a `Void`-typed source unroll and a void CONTRACT, the partially contracted result
/// must keep `DType::Void` while leaving axis 2 unrolled.
#[test]
fn test_contract_void_dtype_preserved() {
    use super::helpers::{create_contract_void, create_unroll_multi_axis_with_dtype};
    use morok_dtype::DType;

    let src = create_unroll_multi_axis_with_dtype(vec![(1, 4), (2, 4)], DType::Void);
    let op = create_contract_void(src, vec![(1, 4)]);
    let lowered = phase2_only(&op);

    assert_eq!(lowered.dtype(), DType::Void);
    let axes_left = unwrap_unroll(&lowered).1;
    assert_eq!(axes_left, vec![(2, 4)]);
}
/// Fully contracting a 4-wide iota unroll should give a result whose vcount (per
/// `assert_vcount`) matches the contract width (4) and whose lanes are `[0, 1, 2, 3]`.
#[test]
fn test_contract_full_uses_output_dtype() {
    let op = create_contract(create_unroll_iota(1, 4), vec![(1, 4)]);
    let lowered = phase2_only(&op);

    assert_vcount(&lowered, 4);
    assert_result_values(&lowered, &[0, 1, 2, 3]);
}