use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, ConstValue, Op, ReduceOp, UOp};
use smallvec::smallvec;
use crate::pattern::RewriteResult;
use crate::rangeify::patterns::pm_load_collapse;
/// Builds a `Reduce`-typed RANGE whose bound is the index constant `end`,
/// tagged with renumbered axis id 0.
fn test_range(end: i64) -> Arc<UOp> {
    let bound = UOp::index_const(end);
    UOp::range_axis(bound, AxisId::Renumbered(0), AxisType::Reduce)
}
/// Wraps `src` in an ADD reduction over the single axis `range`.
fn reduce_add(src: Arc<UOp>, range: Arc<UOp>) -> Arc<UOp> {
    let axes = smallvec![range];
    src.reduce(axes, ReduceOp::Add)
}
/// ADD-reduces `where(r < 5, 1.0, 0.0)` over the range built by `test_range(10)`.
///
/// If `pm_load_collapse` rewrites the REDUCE, the result must no longer be a
/// REDUCE. If the matcher declines (any non-`Rewritten` result), nothing is
/// checked and the test passes.
#[test]
fn test_bounded_sum_below() {
    let range = test_range(10);
    let cut = UOp::index_const(5);
    let cond = range.try_cmplt(&cut).expect("cmplt");
    let val = UOp::native_const(1.0f32);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    // `val` and `range` are not used again, so they are moved rather than cloned.
    let where_expr = UOp::try_where(cond, val, zero).expect("where");
    let reduce = reduce_add(where_expr, range);
    let matcher = pm_load_collapse();
    let result = matcher.rewrite(&reduce, &mut ());
    if let RewriteResult::Rewritten(collapsed) = result {
        assert!(!matches!(collapsed.op(), Op::Reduce { .. }), "Should have eliminated REDUCE");
    }
}
/// ADD-reduces `where(r < 3, 0.0, 1.0)` over the range built by `test_range(10)`,
/// i.e. the value is non-zero only where the index is NOT below 3.
///
/// A successful rewrite must remove the REDUCE; a declined rewrite is not checked.
#[test]
fn test_bounded_sum_above() {
    let axis = test_range(10);
    let threshold = UOp::index_const(3);
    let below = axis.try_cmplt(&threshold).expect("cmplt");
    let one = UOp::native_const(1.0f32);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    // Branches are swapped relative to the "below" test: zero when below the cut.
    let gated = UOp::try_where(below, zero, one).expect("where");
    let summed = reduce_add(gated, axis);
    let matcher = pm_load_collapse();
    let outcome = matcher.rewrite(&summed, &mut ());
    if let RewriteResult::Rewritten(collapsed) = outcome {
        let still_reduce = matches!(collapsed.op(), Op::Reduce { .. });
        assert!(!still_reduce, "Should have eliminated REDUCE");
    }
}
/// Two stacked ADD reductions of the constant `1.0`. The inner one runs over
/// `test_range(5)` (axis id 0) and the outer over a bound-10 reduce range with
/// axis id 1.
///
/// The matcher is required to rewrite the outer REDUCE.
#[test]
fn test_nested_reduce_collapsed_by_full_algorithm() {
    let inner_axis = test_range(5);
    let outer_bound = UOp::index_const(10);
    let outer_axis = UOp::range_axis(outer_bound, AxisId::Renumbered(1), AxisType::Reduce);
    let one = UOp::native_const(1.0f32);
    let nested = reduce_add(reduce_add(one, inner_axis), outer_axis);
    let matcher = pm_load_collapse();
    let outcome = matcher.rewrite(&nested, &mut ());
    let was_rewritten = matches!(outcome, RewriteResult::Rewritten(_));
    assert!(was_rewritten, "Full reduce_load_collapse should collapse nested reduces");
}
/// A MUL reduction must be left alone: the matcher is required to return `NoMatch`.
#[test]
fn test_non_add_reduce_not_collapsed() {
    let axis = test_range(10);
    let one = UOp::native_const(1.0f32);
    // Built directly because `reduce_add` only produces ADD reductions.
    let product = one.reduce(smallvec![axis], ReduceOp::Mul);
    let matcher = pm_load_collapse();
    let outcome = matcher.rewrite(&product, &mut ());
    let untouched = matches!(outcome, RewriteResult::NoMatch);
    assert!(untouched, "Mul reduces should not be handled by load collapse");
}
/// Runs the matcher on `(5 + 3) < 10` built from index constants.
///
/// The assertion accepts both `NoMatch` and `Rewritten`. In practice it only
/// checks that the rewrite completes without panicking, and, if `RewriteResult`
/// has any other variants, that none of them is returned.
#[test]
fn test_arithmetic_lifting_add() {
    let lhs = UOp::index_const(5);
    let rhs = UOp::index_const(3);
    let limit = UOp::index_const(10);
    let sum = lhs.try_add(&rhs).expect("add");
    let less_than = sum.try_cmplt(&limit).expect("cmplt");
    let matcher = pm_load_collapse();
    let outcome = matcher.rewrite(&less_than, &mut ());
    let handled = matches!(outcome, RewriteResult::NoMatch | RewriteResult::Rewritten(_));
    assert!(handled, "Arithmetic lifting should be attempted");
}
/// ADD-reduces `where(!(r < 2) && (r < 7), 1.0, 0.0)` over `test_range(10)`.
///
/// A successful rewrite must remove the REDUCE; a declined rewrite is not checked.
#[test]
fn test_two_sided_bounds() {
    let range = test_range(10);
    let lower = UOp::index_const(2);
    let upper = UOp::index_const(7);
    let lt_lower = range.try_cmplt(&lower).expect("cmplt");
    let ge_lower = lt_lower.not();
    let below_upper = range.try_cmplt(&upper).expect("cmplt");
    // `and_` borrows its right operand. The original read `and_(<_upper)`, an
    // entity-mangled `&lt_upper` that does not compile.
    let cond = ge_lower.and_(&below_upper);
    let val = UOp::native_const(1.0f32);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let where_expr = UOp::try_where(cond, val, zero).expect("where");
    let reduce = reduce_add(where_expr, range);
    let matcher = pm_load_collapse();
    let result = matcher.rewrite(&reduce, &mut ());
    if let RewriteResult::Rewritten(collapsed) = result {
        assert!(!matches!(collapsed.op(), Op::Reduce { .. }), "Should have eliminated REDUCE for two-sided bounds");
    }
}
/// Multiplies `5.0` by a Bool constant (built from `ConstValue::Int(1)`) that
/// has been cast to Float32.
///
/// If the matcher rewrites the multiply, the result must be a WHERE ternary.
/// A declined rewrite is not checked.
#[test]
fn test_mul_casted_bool() {
    let flag = UOp::const_(DType::Bool, ConstValue::Int(1));
    let flag_as_f32 = flag.cast(DType::Float32);
    let five = UOp::native_const(5.0f32);
    let product = five.try_mul(&flag_as_f32).expect("mul");
    let matcher = pm_load_collapse();
    let outcome = matcher.rewrite(&product, &mut ());
    if let RewriteResult::Rewritten(rewritten) = outcome {
        let is_where = matches!(rewritten.op(), Op::Ternary(morok_ir::TernaryOp::Where, ..));
        assert!(is_where, "Should convert to WHERE");
    }
}
/// Runs the matcher on `(5 + 3) != 10` built from index constants.
///
/// Only when the result is rewritten into another `Ne` does the test check
/// anything: both operands must still have `DType::Index`.
#[test]
fn test_ne_lifting() {
    let left = UOp::index_const(5);
    let right = UOp::index_const(3);
    let target = UOp::index_const(10);
    let sum = left.try_add(&right).expect("add");
    let not_equal = sum.try_cmpne(&target).expect("cmpne");
    let matcher = pm_load_collapse();
    let outcome = matcher.rewrite(&not_equal, &mut ());
    if let RewriteResult::Rewritten(rewritten) = outcome {
        if let Op::Binary(morok_ir::BinaryOp::Ne, a, b) = rewritten.op() {
            assert_eq!(a.dtype(), DType::Index, "LHS should be Index dtype");
            assert_eq!(b.dtype(), DType::Index, "RHS should be Index dtype");
        }
    }
}
/// ADD-reduces `where(!(r < 7) && (r < 2), 1.0, 0.0)` over `test_range(10)`.
///
/// The lower bound exceeds the upper bound, so the predicate (`r >= 7 && r < 2`)
/// holds for no integer index. A successful rewrite must remove the REDUCE; a
/// declined rewrite is not checked.
#[test]
fn test_two_sided_bounds_lower_gt_upper() {
    let range = test_range(10);
    let lower = UOp::index_const(7);
    let upper = UOp::index_const(2);
    let lt_lower = range.try_cmplt(&lower).expect("cmplt");
    let ge_lower = lt_lower.not();
    let below_upper = range.try_cmplt(&upper).expect("cmplt");
    // `and_` borrows its right operand. The original read `and_(<_upper)`, an
    // entity-mangled `&lt_upper` that does not compile.
    let cond = ge_lower.and_(&below_upper);
    let val = UOp::native_const(1.0f32);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let where_expr = UOp::try_where(cond, val, zero).expect("where");
    let reduce = reduce_add(where_expr, range);
    let matcher = pm_load_collapse();
    let result = matcher.rewrite(&reduce, &mut ());
    if let RewriteResult::Rewritten(collapsed) = result {
        assert!(!matches!(collapsed.op(), Op::Reduce { .. }), "Should have eliminated REDUCE");
    }
}
/// ADD-reduces `where((r >= 3) && (r < 8), 1.0, 0.0)` over `test_range(10)`.
///
/// The lower bound is written with `try_cmpge` instead of a negated `<`.
/// A successful rewrite must remove the REDUCE; a declined rewrite is not checked.
#[test]
fn test_two_sided_bounds_ge_form() {
    let range = test_range(10);
    let lower = UOp::index_const(3);
    let upper = UOp::index_const(8);
    let ge_lower = range.try_cmpge(&lower).expect("cmpge");
    let below_upper = range.try_cmplt(&upper).expect("cmplt");
    // `and_` borrows its right operand. The original read `and_(<_upper)`, an
    // entity-mangled `&lt_upper` that does not compile.
    let cond = ge_lower.and_(&below_upper);
    let val = UOp::native_const(1.0f32);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let where_expr = UOp::try_where(cond, val, zero).expect("where");
    let reduce = reduce_add(where_expr, range);
    let matcher = pm_load_collapse();
    let result = matcher.rewrite(&reduce, &mut ());
    if let RewriteResult::Rewritten(collapsed) = result {
        assert!(!matches!(collapsed.op(), Op::Reduce { .. }), "Should have eliminated REDUCE with GE form");
    }
}
/// ADD-reduces `where(!(r < 0) && (r < 10), 1.0, 0.0)` over `test_range(10)`.
///
/// The bounds coincide with 0 and with the bound passed to `test_range`, so the
/// predicate presumably covers every index of the range.
/// NOTE(review): this assumes the range iterates `0..10`; confirm.
/// A successful rewrite must remove the REDUCE; a declined rewrite is not checked.
#[test]
fn test_two_sided_bounds_at_range_edges() {
    let range = test_range(10);
    let lower = UOp::index_const(0);
    let upper = UOp::index_const(10);
    let lt_lower = range.try_cmplt(&lower).expect("cmplt");
    let ge_lower = lt_lower.not();
    let below_upper = range.try_cmplt(&upper).expect("cmplt");
    // `and_` borrows its right operand. The original read `and_(<_upper)`, an
    // entity-mangled `&lt_upper` that does not compile.
    let cond = ge_lower.and_(&below_upper);
    let val = UOp::native_const(1.0f32);
    let zero = UOp::const_(DType::Float32, ConstValue::Float(0.0));
    let where_expr = UOp::try_where(cond, val, zero).expect("where");
    let reduce = reduce_add(where_expr, range);
    let matcher = pm_load_collapse();
    let result = matcher.rewrite(&reduce, &mut ());
    if let RewriteResult::Rewritten(collapsed) = result {
        assert!(!matches!(collapsed.op(), Op::Reduce { .. }), "Should have eliminated REDUCE for full range");
    }
}