use super::helpers::*;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{AddrSpace, AxisId, AxisType, Op, ReduceOp, UOp};
use smallvec::smallvec;
use std::sync::Arc;
fn create_group_reduce_range(axis_id: usize, size: i64) -> Arc<UOp> {
let end = UOp::const_(DType::Index, ConstValue::Int(size));
UOp::range_axis(end, AxisId::Renumbered(axis_id), AxisType::GroupReduce)
}
fn create_local_range(axis_id: usize, size: i64) -> Arc<UOp> {
let end = UOp::const_(DType::Index, ConstValue::Int(size));
UOp::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Local)
}
fn create_reduce_range(axis_id: usize, size: i64) -> Arc<UOp> {
let end = UOp::const_(DType::Index, ConstValue::Int(size));
UOp::range_axis(end, AxisId::Renumbered(axis_id), AxisType::Reduce)
}
/// A REDUCE over an ordinary `Reduce` range (no GROUP_REDUCE involved) must come
/// back from the expander as a REDUCE with exactly one `Reduce`-typed range.
#[test]
fn test_passthrough_no_group_reduce() {
    let plain_range = create_reduce_range(0, 32);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let original = zero.reduce(smallvec![plain_range], ReduceOp::Add);
    let rewritten = expander_rewrite(&original);

    let Op::Reduce { ranges, .. } = rewritten.op() else {
        panic!("Expected REDUCE, got {:?}", rewritten.op());
    };
    assert_eq!(ranges.len(), 1, "Should have single range");

    let first_is_reduce_axis =
        matches!(ranges[0].op(), Op::Range { axis_type: AxisType::Reduce, .. });
    assert!(first_is_reduce_axis, "Range should be Reduce type");
}
/// A lone GROUP_REDUCE range must be lowered into a LOCAL-address-space BUFFERIZE,
/// and the top-level REDUCE must be re-expressed over `Reduce`-typed RANGE nodes
/// whose axis ids have been renumbered to >= 100.
#[test]
fn test_group_reduce_basic_transformation() {
    let group_range = create_group_reduce_range(0, 16);
    let src = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let reduce = src.reduce(smallvec![group_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);

    let all_nodes = result.toposort();
    let has_local_buf = all_nodes.iter().any(|n| {
        matches!(n.op(), Op::Bufferize { opts, .. } if opts.addrspace == AddrSpace::Local)
    });
    assert!(has_local_buf, "Should create LOCAL BUFFERIZE for shared memory");

    let Op::Reduce { ranges, .. } = result.op() else {
        panic!("Expected REDUCE at top level, got {:?}", result.op());
    };
    // Without this, an empty range list would make the per-range checks below pass vacuously.
    assert!(!ranges.is_empty(), "Final REDUCE should range over at least one axis");

    for range in ranges.iter() {
        // Skipping non-RANGE entries would silently hide a malformed rewrite, so fail loudly.
        let Op::Range { axis_id, axis_type, .. } = range.op() else {
            panic!("Final REDUCE ranges should be RANGE nodes, got {:?}", range.op());
        };
        assert_eq!(*axis_type, AxisType::Reduce, "Final ranges should be Reduce type");
        assert!(
            axis_id.value() >= 100,
            "Ranges should be renumbered (axis_id >= 100), got {}",
            axis_id.value()
        );
    }
}
/// A REDUCE that mixes a GROUP_REDUCE range with an ordinary `Reduce` range must
/// still yield a LOCAL BUFFERIZE, and no REDUCE anywhere in the rewritten graph may
/// keep a GROUP_REDUCE range.
#[test]
fn test_group_reduce_with_mixed_ranges() {
    let grouped = create_group_reduce_range(0, 16);
    let serial = create_reduce_range(1, 32);
    let one = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let original = one.reduce(smallvec![grouped, serial], ReduceOp::Add);
    let rewritten = expander_rewrite(&original);
    let nodes = rewritten.toposort();

    let local_buffer_present = nodes.iter().any(|node| {
        matches!(node.op(), Op::Bufferize { opts, .. } if opts.addrspace == AddrSpace::Local)
    });
    assert!(local_buffer_present, "Should create LOCAL BUFFERIZE");

    // Equivalent to "no REDUCE has a GROUP_REDUCE range" by De Morgan.
    let every_reduce_clean = nodes.iter().all(|node| match node.op() {
        Op::Reduce { ranges, .. } => !ranges
            .iter()
            .any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::GroupReduce, .. })),
        _ => true,
    });
    assert!(every_reduce_clean, "GROUP_REDUCE should be transformed out of final REDUCEs");
}
/// When the value being group-reduced depends on a LOCAL range, every LOCAL
/// BUFFERIZE produced for the GROUP_REDUCE must also list that LOCAL range among its
/// ranges (the assertion message calls this shared memory indexing).
#[test]
fn test_group_reduce_with_local_ranges() {
    let local_range = create_local_range(0, 32);
    let group_range = create_group_reduce_range(1, 16);
    let sixteen = UOp::index_const(16);
    // addr = local * 16 + group, so the reduced value depends on both ranges.
    // Both ranges are used here for the last (local) or second-to-last (group) time,
    // so they are borrowed or moved rather than cloned.
    let addr = local_range
        .try_mul(&sixteen)
        .expect("local * 16 should build a valid index expression")
        .try_add(&group_range)
        .expect("local * 16 + group should build a valid index expression");
    let src = addr.cast(DType::Float32);
    let reduce = src.reduce(smallvec![group_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);

    // Collect the ranges of every LOCAL BUFFERIZE in one pass, then check them.
    let all_nodes = result.toposort();
    let local_buf_ranges: Vec<_> = all_nodes
        .iter()
        .filter_map(|n| match n.op() {
            Op::Bufferize { ranges, opts, .. } if opts.addrspace == AddrSpace::Local => {
                Some(ranges)
            }
            _ => None,
        })
        .collect();
    assert!(!local_buf_ranges.is_empty(), "Should have LOCAL BUFFERIZE for GROUP_REDUCE");

    for ranges in local_buf_ranges {
        let has_local_in_ranges = ranges
            .iter()
            .any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::Local, .. }));
        assert!(
            has_local_in_ranges,
            "BUFFERIZE ranges should include LOCAL range for shared memory indexing"
        );
    }
}
/// For each of Add, Max and Mul, the combining operator of a GROUP_REDUCE must be
/// carried over unchanged to the top-level REDUCE after the rewrite.
#[test]
fn test_group_reduce_preserves_reduce_op() {
    for expected_op in [ReduceOp::Add, ReduceOp::Max, ReduceOp::Mul] {
        let grouped = create_group_reduce_range(0, 8);
        let one = UOp::const_(DType::Float32, ConstValue::Float(1.0));
        let original = one.reduce(smallvec![grouped], expected_op);
        let rewritten = expander_rewrite(&original);

        let Op::Reduce { reduce_op: actual_op, .. } = rewritten.op() else {
            panic!("Expected REDUCE at top level for {:?}", expected_op);
        };
        assert_eq!(*actual_op, expected_op, "Reduce operation should be preserved");
    }
}
/// Going through the `pre_expand` entry point (rather than calling
/// `expander_rewrite` directly) must also strip GROUP_REDUCE ranges from every
/// REDUCE in the resulting graph.
#[test]
fn test_pm_group_for_reduce_in_pipeline() {
    use crate::expand::pre_expand;

    let grouped = create_group_reduce_range(0, 16);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let original = zero.reduce(smallvec![grouped], ReduceOp::Add);
    let expanded = pre_expand(&original);

    // Flatten the ranges of every REDUCE node and look for a surviving GROUP_REDUCE.
    let group_reduce_survived = expanded
        .toposort()
        .iter()
        .filter_map(|node| match node.op() {
            Op::Reduce { ranges, .. } => Some(ranges),
            _ => None,
        })
        .flat_map(|ranges| ranges.iter())
        .any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::GroupReduce, .. }));

    assert!(!group_reduce_survived, "GROUP_REDUCE should be transformed by pm_group_for_reduce");
}
/// A REDUCE over two GROUP_REDUCE ranges must still be lowered into at least one
/// LOCAL-address-space BUFFERIZE.
#[test]
fn test_multiple_group_reduce_ranges() {
    let first_group = create_group_reduce_range(0, 8);
    let second_group = create_group_reduce_range(1, 4);
    let one = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let original = one.reduce(smallvec![first_group, second_group], ReduceOp::Add);
    let rewritten = expander_rewrite(&original);

    let local_buffer_count = rewritten
        .toposort()
        .iter()
        .filter(|node| {
            matches!(node.op(), Op::Bufferize { opts, .. } if opts.addrspace == AddrSpace::Local)
        })
        .count();
    assert!(local_buffer_count > 0, "Should create LOCAL BUFFERIZE for multiple GROUP_REDUCE");
}
/// A REDUCE whose only range is a GROUP_REDUCE must still be a REDUCE at the top
/// level after the rewrite, and that REDUCE must no longer range over a
/// GROUP_REDUCE axis.
#[test]
fn test_group_reduce_only() {
    let group_range = create_group_reduce_range(0, 32);
    let src = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let reduce = src.reduce(smallvec![group_range], ReduceOp::Add);
    let result = expander_rewrite(&reduce);

    let Op::Reduce { ranges, .. } = result.op() else {
        panic!("Expected REDUCE at top level, got {:?}", result.op());
    };
    // Being a REDUCE alone would also hold if the GROUP_REDUCE were left untouched,
    // so also check that the group axis was actually rewritten away.
    let still_grouped = ranges
        .iter()
        .any(|r| matches!(r.op(), Op::Range { axis_type: AxisType::GroupReduce, .. }));
    assert!(!still_grouped, "Top-level REDUCE should no longer range over GROUP_REDUCE");
}