use super::helpers::*;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use smallvec::smallvec;
/// An UNROLL with an empty axis list should be stripped by phase 2, leaving the
/// original scalar constant (value 42.0).
#[test]
fn test_empty_unroll_unwrap() {
    let scalar = UOp::const_(DType::Float32, ConstValue::Float(42.0));
    let unroll = scalar.unroll(vec![]);
    let result = phase2_only(&unroll);
    // `scalar` is itself a Const node, so a result that is pointer-identical to it
    // also lands in the `Const` arm. A separate `Arc::ptr_eq` fallback in the
    // non-Const arm could never succeed, so any non-Const result is a failure.
    match result.op() {
        Op::Const(cv) => {
            assert_eq!(cv.0, ConstValue::Float(42.0), "Should unwrap to original scalar");
        }
        other => panic!("Expected Const or same reference, got {:?}", other),
    }
}
/// Two nested UNROLLs are expected to collapse into one UNROLL whose axis list
/// holds both the inner `(0, 4)` axis and the outer `(1, 2)` axis.
#[test]
fn test_double_unroll_collapse() {
    let vconst = create_vconst_int(vec![0, 1, 2, 3]);
    let inner = vconst.unroll(vec![(0, 4)]);
    let outer = inner.unroll(vec![(1, 2)]);
    let collapsed = phase2_only(&outer);

    // Pull the axis list out first; anything other than a single UNROLL is a failure.
    let axes = match collapsed.op() {
        Op::Unroll { unroll_axes, .. } => unroll_axes,
        other => panic!("Expected UNROLL, got {:?}", other),
    };

    assert_eq!(axes.len(), 2, "Should have combined axes");
    assert!(axes.contains(&(0, 4)), "Should contain inner axis");
    assert!(axes.contains(&(1, 2)), "Should contain outer axis");
}
/// A BARRIER wrapping an UNROLL is run through phase 2. The test accepts two outcomes:
/// - an UNROLL whose axes are preserved and whose source is a BARRIER, or
/// - a BARRIER result, which is not inspected further.
#[test]
fn test_barrier_with_unroll() {
    let vconst = create_vconst_int(vec![0, 1, 2, 3]);
    let unrolled = vconst.unroll(vec![(0, 4)]);
    let barrier = UOp::new(Op::Barrier { src: unrolled, deps: smallvec![] }, DType::Int64.vec(4));

    let result = phase2_only(&barrier);
    let op = result.op();
    if let Op::Unroll { src, unroll_axes } = op {
        assert_eq!(unroll_axes, &[(0, 4)], "Should preserve axes");
        assert!(matches!(src.op(), Op::Barrier { .. }), "Inner should be BARRIER");
    } else if !matches!(op, Op::Barrier { .. }) {
        panic!("Expected UNROLL or BARRIER, got {:?}", op);
    }
}
/// A CONTRACT built over a void-typed `noop` must itself be void. After phase 2, the
/// result may be either a `Noop` or a `Contract`; any other op fails the test.
#[test]
fn test_contract_void_store() {
    let noop = UOp::noop();
    let contract = noop.contract(vec![(0, 4)]);
    assert_eq!(contract.dtype(), DType::Void, "CONTRACT of void should be void");

    let rewritten = phase2_only(&contract);
    let op = rewritten.op();
    assert!(
        matches!(op, Op::Noop | Op::Contract { .. }),
        "Expected Noop or Contract, got {:?}",
        op
    );
}