use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{BinaryOp, Op, ReduceOp, UOp};
use smallvec::smallvec;
use super::helpers::*;
/// A scalar `Add` reduce over a single reduce range is lowered into the
/// accumulator form: the root is no longer a `Reduce`, and both a `DEFINE_REG`
/// and an `END` appear in the rewritten graph.
#[test]
fn test_reduce_scalar_add() {
    let reduce_axis = create_range_reduce(16, 0);
    let one = create_float_const(1.0);
    let reduce_node = create_reduce(one, vec![reduce_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce, "Should transform REDUCE to accumulator pattern");
    assert!(count_define_regs(&lowered) > 0, "Should contain DEFINE_REG");
    assert!(count_ends(&lowered) > 0, "Should contain END");
}
/// A scalar `Mul` reduce over a single reduce range is lowered into the
/// accumulator form, exactly like the `Add` case: the root is no longer a
/// `Reduce`, and the graph contains a `DEFINE_REG` and an `END`.
#[test]
fn test_reduce_scalar_mul() {
    let range = create_range_reduce(8, 0);
    let src = create_float_const(2.0);
    let reduce = create_reduce(src, vec![range], ReduceOp::Mul);
    let result = apply_pm_reduce(&reduce);
    assert!(!matches!(result.op(), Op::Reduce { .. }), "Should transform REDUCE");
    assert!(count_define_regs(&result) > 0, "Should contain DEFINE_REG");
    // The reduce op only changes the combining ALU; the accumulator loop must still be closed.
    assert!(count_ends(&result) > 0, "Should contain END");
}
/// A scalar `Max` reduce over a single reduce range is lowered into the
/// accumulator form, exactly like the `Add` case: the root is no longer a
/// `Reduce`, and the graph contains a `DEFINE_REG` and an `END`.
#[test]
fn test_reduce_scalar_max() {
    let range = create_range_reduce(32, 0);
    let src = create_float_const(0.0);
    let reduce = create_reduce(src, vec![range], ReduceOp::Max);
    let result = apply_pm_reduce(&reduce);
    assert!(!matches!(result.op(), Op::Reduce { .. }), "Should transform REDUCE");
    assert!(count_define_regs(&result) > 0, "Should contain DEFINE_REG");
    // The reduce op only changes the combining ALU; the accumulator loop must still be closed.
    assert!(count_ends(&result) > 0, "Should contain END");
}
/// A scalar `Min` reduce over a single reduce range is lowered into the
/// accumulator form, exactly like the `Add` case: the root is no longer a
/// `Reduce`, and the graph contains a `DEFINE_REG` and an `END`.
#[test]
fn test_reduce_scalar_min() {
    let range = create_range_reduce(32, 0);
    let src = create_float_const(100.0);
    let reduce = create_reduce(src, vec![range], ReduceOp::Min);
    let result = apply_pm_reduce(&reduce);
    assert!(!matches!(result.op(), Op::Reduce { .. }), "Should transform REDUCE");
    assert!(count_define_regs(&result) > 0, "Should contain DEFINE_REG");
    // The reduce op only changes the combining ALU; the accumulator loop must still be closed.
    assert!(count_ends(&result) > 0, "Should contain END");
}
/// Reducing a 4-lane float iota vector across one reduce range (via the
/// `UOp::reduce` builder) is lowered away from `Reduce` and allocates at least
/// one accumulator register.
#[test]
fn test_reduce_vector_to_scalar() {
    let axis = create_range_reduce(16, 0);
    let lanes = create_vector_float_iota(4);
    let reduce_node = lanes.reduce(smallvec![axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce, "Should transform REDUCE");
    let reg_count = count_define_regs(&lowered);
    assert!(reg_count > 0);
}
/// With no reduce ranges at all, the reduce is purely horizontal (across
/// vector lanes). It is still lowered, but no loop exists, so no `DEFINE_REG`
/// accumulator may be introduced.
#[test]
fn test_horizontal_reduce_no_ranges() {
    let lanes = create_vector_float_iota(4);
    let reduce_node = lanes.reduce(smallvec![], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce, "Should transform REDUCE");
    let reg_count = count_define_regs(&lowered);
    assert_eq!(reg_count, 0, "Should not have DEFINE_REG for horizontal-only reduce");
}
/// A `Reduce` whose source and declared result are both 4-lane vectors (no
/// horizontal narrowing needed) is lowered, and the result keeps 4 lanes.
#[test]
fn test_horizontal_reduce_identity() {
    let axis = create_range_reduce(8, 0);
    let lanes = create_vector_float_iota(4);
    let reduce_op_node = Op::Reduce { src: lanes, ranges: smallvec![axis], reduce_op: ReduceOp::Add };
    let reduce_node = UOp::new(reduce_op_node, DType::Float32.vec(4));
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let lane_count = lowered.dtype().vcount();
    assert_eq!(lane_count, 4);
}
/// A 16-lane constant vector reduced into a 4-lane result over one reduce
/// range is lowered, and the output keeps the declared 4-lane width.
#[test]
fn test_horizontal_reduce_16_to_4() {
    let axis = create_range_reduce(8, 0);
    let lanes: smallvec::SmallVec<[Arc<UOp>; 4]> = (0..16_i32)
        .map(|lane| UOp::const_(DType::Float32, ConstValue::Float(f64::from(lane))))
        .collect();
    let wide = UOp::vectorize(lanes);
    let reduce_op_node = Op::Reduce { src: wide, ranges: smallvec![axis], reduce_op: ReduceOp::Add };
    let reduce_node = UOp::new(reduce_op_node, DType::Float32.vec(4));
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let lane_count = lowered.dtype().vcount();
    assert_eq!(lane_count, 4);
}
/// An empty range list makes the reduce purely horizontal. This must be lowered
/// for every reduce op, not just `Add` (the `Add` case alone is already covered
/// by `test_horizontal_reduce_no_ranges`). With no loop, nothing is
/// accumulated, so no `DEFINE_REG` may be emitted.
#[test]
fn test_reduce_empty_ranges() {
    let cases = [
        ("Add", ReduceOp::Add),
        ("Mul", ReduceOp::Mul),
        ("Max", ReduceOp::Max),
        ("Min", ReduceOp::Min),
    ];
    for (name, reduce_op) in cases {
        let src = create_vector_float_iota(4);
        let reduce = src.reduce(smallvec![], reduce_op);
        let result = apply_pm_reduce(&reduce);
        assert!(
            !matches!(result.op(), Op::Reduce { .. }),
            "Should transform horizontal-only REDUCE ({name})"
        );
        assert_eq!(
            count_define_regs(&result),
            0,
            "Should not have DEFINE_REG for horizontal-only reduce ({name})"
        );
    }
}
/// A reduce over a degenerate range of length 1 is still lowered, and the
/// scalar result keeps a lane count of 1.
#[test]
fn test_reduce_single_element() {
    let unit_axis = create_range_reduce(1, 0);
    let value = create_float_const(42.0);
    let reduce_node = create_reduce(value, vec![unit_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let lane_count = lowered.dtype().vcount();
    assert_eq!(lane_count, 1);
}
/// A reduce over two reduce ranges (axis ids 0 and 1) is lowered into the
/// accumulator form, with a `DEFINE_REG` and an `END` present.
#[test]
fn test_reduce_multiple_ranges() {
    let range1 = create_range_reduce(8, 0);
    let range2 = create_range_reduce(4, 1);
    let src = create_float_const(1.0);
    // `src` is not used again, so it is moved rather than cloned.
    let reduce = create_reduce(src, vec![range1, range2], ReduceOp::Add);
    let result = apply_pm_reduce(&reduce);
    assert!(!matches!(result.op(), Op::Reduce { .. }));
    assert!(count_define_regs(&result) > 0);
    assert!(count_ends(&result) > 0);
}
/// The reduce source depends on a thread range (axis 0) while reducing over a
/// separate reduce range (axis 1). The reduce is still lowered with an
/// accumulator register.
#[test]
fn test_input_ranges_include_thread() {
    let thread_axis = create_range_thread(32, 0);
    let reduce_axis = create_range_reduce(16, 1);
    let per_thread = thread_axis.cast(DType::Float32);
    let reduce_node = create_reduce(per_thread, vec![reduce_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let reg_count = count_define_regs(&lowered);
    assert!(reg_count > 0);
}
/// The reduce source depends on a global range (axis 0) while reducing over a
/// separate reduce range (axis 1). The reduce is still lowered with an
/// accumulator register.
#[test]
fn test_input_ranges_include_global() {
    let global_axis = create_range_global(64, 0);
    let reduce_axis = create_range_reduce(16, 1);
    let per_global = global_axis.cast(DType::Float32);
    let reduce_node = create_reduce(per_global, vec![reduce_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let reg_count = count_define_regs(&lowered);
    assert!(reg_count > 0);
}
/// The reduce source depends on a local range (axis 0) while reducing over a
/// separate reduce range (axis 1). The reduce is still lowered with an
/// accumulator register.
#[test]
fn test_input_ranges_include_local() {
    let local_axis = create_range_local(16, 0);
    let reduce_axis = create_range_reduce(8, 1);
    let per_local = local_axis.cast(DType::Float32);
    let reduce_node = create_reduce(per_local, vec![reduce_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let reg_count = count_define_regs(&lowered);
    assert!(reg_count > 0);
}
/// The reduce source depends on a loop range (axis 0) while reducing over a
/// separate reduce range (axis 1). The reduce is still lowered with an
/// accumulator register.
#[test]
fn test_input_ranges_include_loop() {
    let loop_axis = create_range_loop(8, 0);
    let reduce_axis = create_range_reduce(16, 1);
    let per_iteration = loop_axis.cast(DType::Float32);
    let reduce_node = create_reduce(per_iteration, vec![reduce_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let reg_count = count_define_regs(&lowered);
    assert!(reg_count > 0);
}
/// The reduce source is derived from the very range being reduced over. The
/// reduce is still lowered with an accumulator register.
#[test]
fn test_input_ranges_exclude_reduce_range() {
    let reduce_axis = create_range_reduce(16, 0);
    // Keep the original handle for the reduce's range list; cast a clone.
    let axis_as_float = reduce_axis.clone().cast(DType::Float32);
    let reduce_node = create_reduce(axis_as_float, vec![reduce_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let reg_count = count_define_regs(&lowered);
    assert!(reg_count > 0);
}
/// The reduce source sums casts of a global, a thread and a loop range (axes
/// 0-2) and reduces over a separate reduce range (axis 3). The reduce is still
/// lowered with an accumulator register.
#[test]
fn test_input_ranges_mixed_axis_types() {
    let global_axis = create_range_global(64, 0);
    let thread_axis = create_range_thread(32, 1);
    let loop_axis = create_range_loop(8, 2);
    let reduce_axis = create_range_reduce(16, 3);

    // Build ((global + thread) + loop) bottom-up, in the same evaluation order.
    let global_f = global_axis.cast(DType::Float32);
    let thread_f = thread_axis.cast(DType::Float32);
    let global_plus_thread = UOp::new(Op::Binary(BinaryOp::Add, global_f, thread_f), DType::Float32);
    let loop_f = loop_axis.cast(DType::Float32);
    let summed = UOp::new(Op::Binary(BinaryOp::Add, global_plus_thread, loop_f), DType::Float32);

    let reduce_node = create_reduce(summed, vec![reduce_axis], ReduceOp::Add);
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let reg_count = count_define_regs(&lowered);
    assert!(reg_count > 0);
}
/// Runs `pm_reduce` combined with the GEP-pushing patterns through the real
/// `graph_rewrite` driver, rather than the `apply_pm_reduce` helper. Checks
/// that a `Reduce` over an indexed load is lowered to an accumulator
/// (`DEFINE_REG`).
#[test]
fn test_reduce_in_full_pipeline() {
    use crate::devectorize::pm_reduce;
    use crate::rewrite::graph_rewrite;
    use crate::symbolic::patterns::gep_pushing_patterns;
    use morok_dtype::{AddrSpace, DeviceSpec};

    let reduce_range = create_range_reduce(32, 0);
    let buffer_dtype = DType::Float32.ptr(Some(1024), AddrSpace::Global);
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, 1024, buffer_dtype.clone());
    let define = UOp::param(0, 1024, buffer_dtype, None);
    // NOTE(review): the INDEX is built on `define` while the LOAD below names
    // `buffer`. Confirm this pairing is intentional rather than a mix-up.
    let idx = UOp::index()
        .buffer(define)
        .indices(vec![reduce_range.clone()])
        .call()
        .expect("INDEX over a single reduce range should build");
    // `buffer` is not used after this point, so it is moved rather than cloned.
    let load = UOp::load().buffer(buffer).index(idx).call();
    let reduce = load.reduce(smallvec![reduce_range], ReduceOp::Add);

    let combined = pm_reduce() + gep_pushing_patterns().with_context();
    let mut ctx = crate::devectorize::ReduceContext::default();
    let result = graph_rewrite(&combined, reduce, &mut ctx);
    assert!(!matches!(result.op(), Op::Reduce { .. }), "REDUCE should be transformed");
    assert!(count_define_regs(&result) > 0, "Should have DEFINE_REG for accumulator");
}
/// A `Reduce` whose source is an explicit 4-lane `vectorize` of constants and
/// whose declared result is 4 lanes wide is lowered, and the output keeps
/// 4 lanes.
#[test]
fn test_reduce_with_vectorized_source() {
    let axis = create_range_reduce(16, 0);
    let lanes: smallvec::SmallVec<[Arc<UOp>; 4]> = (0..4_i32)
        .map(|lane| UOp::const_(DType::Float32, ConstValue::Float(f64::from(lane))))
        .collect();
    let packed = UOp::vectorize(lanes);
    let reduce_op_node = Op::Reduce { src: packed, ranges: smallvec![axis], reduce_op: ReduceOp::Add };
    let reduce_node = UOp::new(reduce_op_node, DType::Float32.vec(4));
    let lowered = apply_pm_reduce(&reduce_node);

    let still_reduce = matches!(lowered.op(), Op::Reduce { .. });
    assert!(!still_reduce);
    let lane_count = lowered.dtype().vcount();
    assert_eq!(lane_count, 4);
}