use super::helpers::*;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{AxisId, AxisType, Op, ReduceOp, UOp};
use smallvec::smallvec;
#[test]
fn test_fix_reduce_simple_passthrough() {
    // A REDUCE over a single Reduce-axis range has nothing to unroll: the rewrite
    // should return a REDUCE that keeps that one range and has no CONTRACT source.
    let reduce_range = UOp::range_axis(
        UOp::const_(DType::Index, ConstValue::Int(32)),
        AxisId::Renumbered(0),
        AxisType::Reduce,
    );
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let rewritten = expander_rewrite(&zero.reduce(smallvec![reduce_range], ReduceOp::Add));

    let Op::Reduce { src: fixed_src, ranges, .. } = rewritten.op() else {
        panic!("Expected REDUCE, got {:?}", rewritten.op());
    };
    assert_eq!(ranges.len(), 1, "Should still have one range");
    assert!(matches!(ranges[0].op(), Op::Range { axis_type: AxisType::Reduce, .. }));
    assert!(!matches!(fixed_src.op(), Op::Contract { .. }), "Should not have CONTRACT wrapper");
}
#[test]
fn test_fix_reduce_loop_passthrough() {
    // A Loop-axis range is not expanded either: the REDUCE should come back with its
    // single Loop range intact and a source that is not wrapped in CONTRACT.
    let loop_end = UOp::const_(DType::Index, ConstValue::Int(16));
    let loop_range = UOp::range_axis(loop_end, AxisId::Renumbered(0), AxisType::Loop);
    let one = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let rewritten = expander_rewrite(&one.reduce(smallvec![loop_range], ReduceOp::Add));

    let (inner, kept_ranges) = match rewritten.op() {
        Op::Reduce { src, ranges, .. } => (src, ranges),
        unexpected => panic!("Expected REDUCE, got {:?}", unexpected),
    };
    assert_eq!(kept_ranges.len(), 1, "Should still have one range");
    assert!(matches!(kept_ranges[0].op(), Op::Range { axis_type: AxisType::Loop, .. }));
    assert!(!matches!(inner.op(), Op::Contract { .. }), "Should not have CONTRACT wrapper");
}
#[test]
fn test_fix_reduce_range_unroll() {
    // An Unroll-axis range listed on a REDUCE must be folded into the source
    // (CONTRACT/VECTORIZE/VCONST) and must not remain among the REDUCE's ranges,
    // neither as an UNROLL op nor as the original Unroll-typed RANGE.
    let end = UOp::const_(DType::Index, ConstValue::Int(4));
    let unroll_range = UOp::range_axis(end, AxisId::Renumbered(1), AxisType::Unroll);
    let reduce_range =
        UOp::range_axis(UOp::const_(DType::Index, ConstValue::Int(32)), AxisId::Renumbered(0), AxisType::Reduce);
    let src = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let reduce = src.reduce(smallvec![reduce_range, unroll_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);
    match result.op() {
        Op::Reduce { src: fixed_src, ranges, .. } => {
            for range in ranges.iter() {
                // The input range is an `Op::Range` with Unroll axis type, so checking only
                // for `Op::Unroll` would pass vacuously if the range were left unconverted.
                assert!(
                    !matches!(range.op(), Op::Unroll { .. } | Op::Range { axis_type: AxisType::Unroll, .. }),
                    "UNROLL should be removed from ranges, got {:?}",
                    range.op()
                );
            }
            assert!(
                matches!(fixed_src.op(), Op::Contract { .. } | Op::Vectorize { .. } | Op::VConst { .. }),
                "Source should be expanded (CONTRACT/VECTORIZE/VCONST), got {:?}",
                fixed_src.op()
            );
        }
        other => panic!("Expected REDUCE, got {:?}", other),
    }
}
#[test]
fn test_fix_reduce_unroll_vectorizes_source() {
    // Reducing a scalar source over a 4-wide Unroll range should broadcast the source
    // into a 4-element VECTORIZE. The rewrite may return either a REDUCE wrapping that
    // VECTORIZE or the VECTORIZE itself.
    let unroll_end = UOp::const_(DType::Index, ConstValue::Int(4));
    let unroll_range = UOp::range_axis(unroll_end, AxisId::Renumbered(1), AxisType::Unroll);
    let src = UOp::define_var("x".into(), 0, 100).cast(DType::Float32);
    let reduce = src.reduce(smallvec![unroll_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);
    let vectorized = match result.op() {
        Op::Reduce { src: fixed_src, ranges, .. } => {
            // `any` over an empty list is false, so no separate emptiness check is needed.
            // Reject both an UNROLL op and an unconverted Unroll-typed RANGE.
            assert!(
                !ranges
                    .iter()
                    .any(|r| matches!(r.op(), Op::Unroll { .. } | Op::Range { axis_type: AxisType::Unroll, .. })),
                "UNROLL should be removed from ranges"
            );
            fixed_src.clone()
        }
        Op::Vectorize { .. } => result.clone(),
        other => panic!("Expected REDUCE or VECTORIZE, got {:?}", other),
    };
    let Op::Vectorize { elements } = vectorized.op() else {
        panic!("Expected VECTORIZE, got {:?}", vectorized.op());
    };
    assert_eq!(elements.len(), 4, "Should broadcast to 4 elements");
    for elem in elements.iter() {
        assert!(
            matches!(elem.op(), Op::Cast { .. } | Op::DefineVar { .. }),
            "VECTORIZE elements should be Cast or DefineVar, got {:?}",
            elem.op()
        );
    }
}
#[test]
fn test_fix_reduce_range_upcast() {
    // An Upcast-axis range on a REDUCE is expanded like an Unroll range: the source
    // becomes CONTRACT/VECTORIZE/VCONST, the Upcast range leaves the REDUCE's ranges,
    // and the Loop range is kept.
    let upcast_end = UOp::const_(DType::Index, ConstValue::Int(4));
    let upcast_range = UOp::range_axis(upcast_end, AxisId::Renumbered(1), AxisType::Upcast);
    let loop_end = UOp::const_(DType::Index, ConstValue::Int(16));
    let loop_range = UOp::range_axis(loop_end, AxisId::Renumbered(2), AxisType::Loop);
    let src = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let reduce = src.reduce(smallvec![upcast_range, loop_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);
    match result.op() {
        Op::Reduce { src: fixed_src, ranges, .. } => {
            assert!(
                matches!(fixed_src.op(), Op::Contract { .. } | Op::Vectorize { .. } | Op::VConst { .. }),
                "Source should be expanded (CONTRACT/VECTORIZE/VCONST), got {:?}",
                fixed_src.op()
            );
            assert!(
                ranges.iter().any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::Loop, .. })),
                "Should preserve Loop range"
            );
            // The upcast axis must be gone, either as the original RANGE or as an UNROLL op.
            assert!(
                !ranges
                    .iter()
                    .any(|r| matches!(r.op(), Op::Unroll { .. } | Op::Range { axis_type: AxisType::Upcast, .. })),
                "UPCAST range should be removed from ranges"
            );
        }
        other => panic!("Expected REDUCE, got {:?}", other),
    }
}
#[test]
fn test_fix_reduce_multiple_unrolls() {
    // Two 2-wide Unroll ranges on one REDUCE combine into a single 4-wide vector source.
    let unroll1_end = UOp::const_(DType::Index, ConstValue::Int(2));
    let unroll1_range = UOp::range_axis(unroll1_end, AxisId::Renumbered(1), AxisType::Unroll);
    let unroll2_end = UOp::const_(DType::Index, ConstValue::Int(2));
    let unroll2_range = UOp::range_axis(unroll2_end, AxisId::Renumbered(2), AxisType::Unroll);
    let src = UOp::define_var("x".into(), 0, 100).cast(DType::Float32);
    let reduce = src.reduce(smallvec![unroll1_range, unroll2_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);
    let vectorized = match result.op() {
        Op::Reduce { src: fixed_src, ranges, .. } => {
            // `any` over an empty list is false, so no separate emptiness check is needed.
            // Reject both an UNROLL op and an unconverted Unroll-typed RANGE.
            assert!(
                !ranges
                    .iter()
                    .any(|r| matches!(r.op(), Op::Unroll { .. } | Op::Range { axis_type: AxisType::Unroll, .. })),
                "UNROLL should be removed from ranges"
            );
            fixed_src.clone()
        }
        Op::Vectorize { .. } => result.clone(),
        other => panic!("Expected REDUCE or VECTORIZE, got {:?}", other),
    };
    assert_eq!(vectorized.dtype().vcount(), 4, "Combined UNROLL should vectorize to 4 elements");
}
#[test]
fn test_fix_reduce_mixed_ranges() {
    // With Reduce, Unroll and Loop ranges together, only the Unroll range is folded into
    // the (expanded) source; the Reduce and Loop ranges must survive on the REDUCE.
    let reduce_end = UOp::const_(DType::Index, ConstValue::Int(32));
    let reduce_range = UOp::range_axis(reduce_end, AxisId::Renumbered(0), AxisType::Reduce);
    let unroll_end = UOp::const_(DType::Index, ConstValue::Int(4));
    let unroll_range = UOp::range_axis(unroll_end, AxisId::Renumbered(1), AxisType::Unroll);
    let loop_end = UOp::const_(DType::Index, ConstValue::Int(8));
    let loop_range = UOp::range_axis(loop_end, AxisId::Renumbered(2), AxisType::Loop);
    let src = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let reduce = src.reduce(smallvec![reduce_range, unroll_range, loop_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);
    match result.op() {
        Op::Reduce { src: fixed_src, ranges, .. } => {
            let has_reduce = ranges.iter().any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::Reduce, .. }));
            let has_loop = ranges.iter().any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::Loop, .. }));
            // Also reject an unconverted Unroll-typed RANGE: checking only for `Op::Unroll`
            // would pass if the input range were left untouched.
            let has_unroll = ranges
                .iter()
                .any(|r| matches!(r.op(), Op::Unroll { .. } | Op::Range { axis_type: AxisType::Unroll, .. }));
            assert!(has_reduce, "Should preserve Reduce range");
            assert!(has_loop, "Should preserve Loop range");
            assert!(!has_unroll, "UNROLL should be removed from ranges");
            assert!(
                matches!(fixed_src.op(), Op::Contract { .. } | Op::Vectorize { .. } | Op::VConst { .. }),
                "Source should be expanded (CONTRACT/VECTORIZE/VCONST), got {:?}",
                fixed_src.op()
            );
        }
        other => panic!("Expected REDUCE, got {:?}", other),
    }
}
#[test]
fn test_fix_reduce_unroll_source_with_unroll_range() {
    // The source is already an UNROLL over axis 1 (width 4) and the REDUCE also lists an
    // Unroll range on axis 1. If a REDUCE survives, it must no longer carry that axis.
    let var = UOp::define_var("x".into(), 0, 100).cast(DType::Float32);
    let src = var.unroll_with_dtype(vec![(1, 4)], DType::Float32);
    let unroll_end = UOp::const_(DType::Index, ConstValue::Int(4));
    let unroll_range = UOp::range_axis(unroll_end, AxisId::Renumbered(1), AxisType::Unroll);
    let reduce = src.reduce(smallvec![unroll_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);
    match result.op() {
        Op::Reduce { ranges, .. } => {
            // Reject both an UNROLL op and an unconverted Unroll-typed RANGE.
            assert!(
                !ranges
                    .iter()
                    .any(|r| matches!(r.op(), Op::Unroll { .. } | Op::Range { axis_type: AxisType::Unroll, .. })),
                "UNROLL should be removed from ranges"
            );
        }
        // These non-REDUCE results are accepted as-is; this test only constrains the
        // REDUCE case.
        Op::Vectorize { .. } | Op::VConst { .. } | Op::Unroll { .. } => {}
        other => panic!("Expected REDUCE, VECTORIZE, VCONST, or UNROLL, got {:?}", other),
    }
}
#[test]
fn test_fix_reduce_single_unroll_only() {
    // A REDUCE whose only range is a 4-wide Unroll range should end up with a vec4 source.
    // The rewrite may return either a REDUCE wrapping it or the VECTORIZE itself.
    let unroll_end = UOp::const_(DType::Index, ConstValue::Int(4));
    let unroll_range = UOp::range_axis(unroll_end, AxisId::Renumbered(0), AxisType::Unroll);
    let src = UOp::define_var("x".into(), 0, 100).cast(DType::Float32);
    let reduce = src.reduce(smallvec![unroll_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);
    let vectorized = match result.op() {
        Op::Reduce { src: fixed_src, ranges, .. } => {
            // Reject both an UNROLL op and an unconverted Unroll-typed RANGE.
            assert!(
                !ranges
                    .iter()
                    .any(|r| matches!(r.op(), Op::Unroll { .. } | Op::Range { axis_type: AxisType::Unroll, .. })),
                "UNROLL should be removed from ranges"
            );
            fixed_src.clone()
        }
        Op::Vectorize { .. } => result.clone(),
        other => panic!("Expected REDUCE or VECTORIZE, got {:?}", other),
    };
    assert_eq!(vectorized.dtype().vcount(), 4, "Source should be vec4");
}
#[test]
fn test_fix_reduce_unroll_size_1() {
    // A width-1 Unroll range is degenerate. The test only requires the rewrite to produce
    // one of the accepted node kinds listed below.
    let unroll_range = UOp::range_axis(
        UOp::const_(DType::Index, ConstValue::Int(1)),
        AxisId::Renumbered(0),
        AxisType::Unroll,
    );
    let x = UOp::define_var("x".into(), 0, 100).cast(DType::Float32);
    let rewritten = expander_rewrite(&x.reduce(smallvec![unroll_range], ReduceOp::Add));

    let op = rewritten.op();
    let accepted = matches!(op, Op::Reduce { .. } | Op::Vectorize { .. } | Op::DefineVar { .. });
    assert!(accepted, "Expected REDUCE, VECTORIZE, or DefineVar, got {:?}", op);
}