pub mod bufferize_unroll;
pub mod do_contract;
pub mod do_expand;
pub mod edge_cases;
pub mod end_unrolls;
pub mod fix_reduce;
pub mod fix_store;
pub mod group_reduce;
pub mod helpers;
pub mod shift_to_integration;
pub mod swizzle;
use crate::expand::*;
use morok_ir::{AxisType, prelude::*};
/// Runs `pre_expand` on a REDUCE whose only range is a plain `Reduce`-typed
/// `Range` (no UNROLL involved) and checks that the result is still a REDUCE
/// carrying exactly one `Reduce`-typed range.
#[test]
fn test_pre_expand_passthrough() {
    let bound = UOp::const_(DType::Index, ConstValue::Int(32));
    let reduce_axis = UOp::range_axis(bound, morok_ir::AxisId::Renumbered(0), AxisType::Reduce);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let reduce = zero.reduce(smallvec::smallvec![reduce_axis.clone()], ReduceOp::Add);

    let expanded = pre_expand(&reduce);

    // Guard first: anything other than a REDUCE is an immediate failure.
    let Op::Reduce { ranges, .. } = expanded.op() else {
        panic!("Expected REDUCE op");
    };
    assert_eq!(ranges.len(), 1);
    assert!(matches!(ranges[0].op(), Op::Range { axis_type: AxisType::Reduce, .. }));
}
/// Builds a binary `Add` whose operands are a scalar constant and a 3-lane
/// UNROLL, runs `pre_expand` on it, and accepts either an UNROLL or a Binary
/// node as the result.
#[test]
fn test_vectorize_expansion_with_mixed_sources() {
    let lanes = vec![ConstValue::Int(0), ConstValue::Int(1), ConstValue::Int(2)];
    let lane_values = UOp::vconst(lanes, DType::Int64);
    let unrolled = lane_values.unroll(vec![(0, 3)]);
    let one = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let add = UOp::new(Op::Binary(morok_ir::BinaryOp::Add, one.clone(), unrolled.clone()), DType::Float32);

    let expanded = pre_expand(&add);

    let is_unroll_or_binary = matches!(expanded.op(), Op::Unroll { .. } | Op::Binary(..));
    assert!(is_unroll_or_binary, "Expected UNROLL or Binary, got {:?}", expanded.op());
}
/// Runs `pre_expand` on a VECTORIZE built from two scalar constants and checks
/// that the resulting dtype still reports a vector count of 2.
#[test]
fn test_vectorize_all_scalar_sources() {
    let vectorize = UOp::vectorize(smallvec::smallvec![
        UOp::const_(DType::Float32, ConstValue::Float(1.0)),
        UOp::const_(DType::Float32, ConstValue::Float(2.0)),
    ]);

    let expanded = pre_expand(&vectorize);

    let lane_count = expanded.dtype().vcount();
    assert_eq!(lane_count, 2);
}
/// Checks `fix_reduce_unroll` on a REDUCE whose range list mixes a `Range`
/// with an UNROLL.
///
/// The rewrite must return `Some`, the returned node must be a REDUCE whose
/// source is wrapped in a CONTRACT, and every remaining entry in its range
/// list must be a `Range` op.
#[test]
fn test_fix_reduce_unroll_with_unroll_ops() {
    let values =
        UOp::vconst(vec![ConstValue::Int(0), ConstValue::Int(1), ConstValue::Int(2), ConstValue::Int(3)], DType::Int64);
    let unroll = values.unroll(vec![(1, 4)]);
    let reduce_end = UOp::const_(DType::Index, ConstValue::Int(16));
    let reduce_range = UOp::range_axis(reduce_end, morok_ir::AxisId::Renumbered(0), AxisType::Reduce);
    let src = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let reduce = src.reduce(smallvec::smallvec![reduce_range.clone(), unroll], ReduceOp::Add);

    let Some(fixed) = fix_reduce_unroll(&reduce) else {
        panic!("Expected Some when UNROLL is in ranges");
    };

    // Fail loudly on a non-REDUCE result. The previous `if let` chain skipped
    // the assertions below in that case, letting the test pass vacuously.
    let Op::Reduce { src: fixed_src, ranges, .. } = fixed.op() else {
        panic!("Expected REDUCE op, got {:?}", fixed.op());
    };
    assert!(matches!(fixed_src.op(), Op::Contract { .. }), "Expected CONTRACT wrapper");
    assert!(ranges.iter().all(|r| matches!(r.op(), Op::Range { .. })), "All ranges should be Range ops");
}
/// Runs `pre_expand` on a REDUCE over an indexed buffer whose range list
/// holds a `Range` plus a 2-lane UNROLL, and fails if the result is a REDUCE
/// with an empty range list.
///
/// A result that is not a REDUCE at all is accepted; only the "REDUCE with no
/// ranges" shape is treated as the bug. The before/after trees are printed to
/// aid debugging when the test fails.
#[test]
fn test_reduce_empty_ranges_bug() {
    let data_buf = UOp::buffer_id(Some(0));
    let index = UOp::index().buffer(data_buf).indices(vec![UOp::index_const(0)]).call().expect("index");
    let reduce_end = UOp::const_(DType::Index, ConstValue::Int(2));
    let reduce_range = UOp::range_axis(reduce_end, morok_ir::AxisId::Renumbered(0), AxisType::Reduce);
    let values = UOp::vconst(vec![ConstValue::Int(0), ConstValue::Int(1)], DType::Int64);
    let unroll = values.unroll(vec![(1, 2)]);
    let reduce = index.reduce(smallvec::smallvec![reduce_range.clone(), unroll], ReduceOp::Add);

    println!("BEFORE pre_expand:");
    println!("REDUCE: {}", reduce.tree());
    let result = pre_expand(&reduce);
    println!("AFTER pre_expand:");
    println!("RESULT: {}", result.tree());

    if let Op::Reduce { ranges, .. } = result.op() {
        assert!(
            !ranges.is_empty(),
            "BUG: REDUCE has empty ranges after pre_expand - this causes horizontal_reduce to return unchanged input"
        );
    }
}