use super::helpers::*;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{AxisId, AxisType, Op, UOp};
use smallvec::smallvec;
/// Checks that `phase2_only` moves an UNROLL out of a STORE's ranges.
///
/// The STORE is built with a single range produced by `create_unroll_iota(0, 4)`. The test
/// accepts two result shapes:
/// - `CONTRACT(STORE(..))`, where the CONTRACT carries upcast ranges `[(0, 4)]` and the inner
///   STORE has no UNROLL left in its ranges.
/// - A bare STORE, as long as no UNROLL remains in its ranges.
///
/// Any other op fails the test.
#[test]
fn test_fix_store_partition() {
    let index = UOp::index_const(0);
    let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let unroll = create_unroll_iota(0, 4);
    let store = index.store_with_ranges(value, smallvec![unroll]);

    let result = phase2_only(&store);

    match result.op() {
        Op::Contract { src, upcast_ranges } => {
            assert_eq!(upcast_ranges, &[(0, 4)], "CONTRACT should have axis from UNROLL");
            match src.op() {
                Op::Store { ranges, .. } => {
                    // `any` is false on an empty iterator, so this also covers the case
                    // where every range was moved out of the STORE.
                    assert!(
                        !ranges.iter().any(|r| matches!(r.op(), Op::Unroll { .. })),
                        "STORE ranges should not contain UNROLL"
                    );
                }
                other => panic!("Expected STORE inside CONTRACT, got {:?}", other),
            }
        }
        Op::Store { ranges, .. } => {
            assert!(
                !ranges.iter().any(|r| matches!(r.op(), Op::Unroll { .. })),
                "STORE ranges should not contain UNROLL"
            );
        }
        other => panic!("Expected CONTRACT or STORE, got {:?}", other),
    }
}
/// Checks `phase2_only` on a STORE whose ranges mix an UNROLL with a Loop `Range`.
///
/// The STORE is built with two ranges: one from `create_unroll_iota(0, 4)` and one Loop range
/// (`AxisId::Renumbered(1)`, end constant 16). The test accepts two result shapes:
/// - `CONTRACT(STORE(..))`, where the CONTRACT carries upcast ranges `[(0, 4)]` and the inner
///   STORE keeps exactly one range, the Loop range.
/// - A bare STORE, whose ranges must contain no UNROLL and must still include the Loop range.
///
/// Any other op fails the test.
#[test]
fn test_fix_store_mixed_ranges() {
    let index = UOp::index_const(0);
    let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let unroll = create_unroll_iota(0, 4);
    let end = UOp::const_(DType::Index, ConstValue::Int(16));
    let loop_range = UOp::range_axis(end, AxisId::Renumbered(1), AxisType::Loop);
    // `loop_range` is not used after this point, so it is moved into the STORE, not cloned.
    let store = index.store_with_ranges(value, smallvec![unroll, loop_range]);

    let result = phase2_only(&store);

    match result.op() {
        Op::Contract { src, upcast_ranges } => {
            assert_eq!(upcast_ranges, &[(0, 4)], "CONTRACT should have UNROLL axis");
            match src.op() {
                Op::Store { ranges, .. } => {
                    assert_eq!(ranges.len(), 1, "STORE should have one non-UNROLL range");
                    assert!(
                        matches!(ranges[0].op(), Op::Range { axis_type: AxisType::Loop, .. }),
                        "Should preserve Loop range"
                    );
                }
                other => panic!("Expected STORE inside CONTRACT, got {:?}", other),
            }
        }
        Op::Store { ranges, .. } => {
            // A bare STORE must still have had its UNROLL removed. Checking only for
            // "some Range" would also pass when the rewrite did nothing at all.
            assert!(
                !ranges.iter().any(|r| matches!(r.op(), Op::Unroll { .. })),
                "STORE ranges should not contain UNROLL"
            );
            let keeps_loop = ranges
                .iter()
                .any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::Loop, .. }));
            assert!(keeps_loop, "Should preserve Loop range");
        }
        other => panic!("Expected CONTRACT or STORE, got {:?}", other),
    }
}