use std::{f32::consts::PI, sync::Arc};
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, BinaryOp, Op, ReduceOp, UOp, pattern::RewriteResult};
use crate::rangeify::transforms::reduce_collapse as reduce_collapse_inner;
/// Runs the `pm_reduce_simplify` pattern matcher once on `reduce`.
///
/// Returns `Some(new_node)` only when the matcher reports
/// `RewriteResult::Rewritten`; every other outcome yields `None`.
fn reduce_unparented(reduce: &Arc<UOp>) -> Option<Arc<UOp>> {
    let RewriteResult::Rewritten(simplified) =
        crate::rangeify::patterns::pm_reduce_simplify().rewrite(reduce, &mut ())
    else {
        return None;
    };
    Some(simplified)
}
/// Unpacks a `Reduce` node and forwards its source and ranges to the
/// `reduce_collapse` transform. Any other op yields `None`.
fn reduce_collapse(reduce: &Arc<UOp>) -> Option<Arc<UOp>> {
    match reduce.op() {
        Op::Reduce { src, ranges, .. } => reduce_collapse_inner(src, ranges),
        _ => None,
    }
}
/// Returns `true` if any node in `uop.toposort()` is a `Reduce` or
/// `ReduceAxis` op.
fn has_reduce_op(uop: &Arc<UOp>) -> bool {
    for node in uop.toposort().iter() {
        if matches!(node.op(), Op::ReduceAxis { .. } | Op::Reduce { .. }) {
            return true;
        }
    }
    false
}
/// Returns `true` if any node in `uop.toposort()` is a `Range` op.
fn has_ranges_in_graph(uop: &Arc<UOp>) -> bool {
    for node in uop.toposort().iter() {
        if matches!(node.op(), Op::Range { .. }) {
            return true;
        }
    }
    false
}
/// A SUM whose source never references the reduced range is rewritten to a `Mul` node.
#[test]
fn test_reduce_unparented_add_basic() {
    let five = UOp::native_const(5i32);
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let summed = five.reduce(vec![axis].into(), ReduceOp::Add);

    let simplified = reduce_unparented(&summed).expect("Should simplify");
    assert!(matches!(simplified.op(), Op::Binary(BinaryOp::Mul, _, _)));
}
/// A PROD whose source never references the reduced range is rewritten to a `Pow` node.
#[test]
fn test_reduce_unparented_mul() {
    let base = UOp::native_const(2i32);
    let axis = UOp::range_axis(UOp::index_const(3), AxisId::Renumbered(0), AxisType::Reduce);
    let product = base.reduce(vec![axis].into(), ReduceOp::Mul);

    let simplified = reduce_unparented(&product).expect("Should simplify");
    assert!(matches!(simplified.op(), Op::Binary(BinaryOp::Pow, _, _)));
}
/// MAX over an unreferenced range returns the very same source node (pointer-equal).
#[test]
fn test_reduce_unparented_max() {
    let value = UOp::native_const(42i32);
    let axis = UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce);
    let maxed = Arc::clone(&value).reduce(vec![axis].into(), ReduceOp::Max);

    let simplified = reduce_unparented(&maxed).expect("Should simplify");
    assert!(Arc::ptr_eq(&simplified, &value));
}
/// MIN over an unreferenced range returns the very same source node (pointer-equal).
#[test]
fn test_reduce_unparented_min() {
    let value = UOp::native_const(42i32);
    let axis = UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce);
    let minned = Arc::clone(&value).reduce(vec![axis].into(), ReduceOp::Min);

    let simplified = reduce_unparented(&minned).expect("Should simplify");
    assert!(Arc::ptr_eq(&simplified, &value));
}
/// When the source is the reduced range itself, the matcher reports no rewrite.
#[test]
fn test_reduce_unparented_all_parented() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let reduce = Arc::clone(&axis).reduce(vec![axis].into(), ReduceOp::Add);

    assert!(reduce_unparented(&reduce).is_none());
}
/// With one referenced and one unreferenced range, the result is a `Mul` whose
/// first operand is a REDUCE that keeps only the referenced range.
#[test]
fn test_reduce_unparented_mixed_ranges() {
    let used = UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce);
    let unused = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(1), AxisType::Reduce);
    let three = UOp::native_const(3i32);
    let body = three.try_add(&used.cast(DType::Int32)).unwrap();
    let reduce = body.reduce(vec![used.clone(), unused].into(), ReduceOp::Add);

    let simplified = reduce_unparented(&reduce).expect("Should simplify");
    assert!(matches!(simplified.op(), Op::Binary(BinaryOp::Mul, _, _)));

    // The assert above guarantees a Binary op; this guard only unpacks it.
    let Op::Binary(_, lhs, _) = simplified.op() else {
        return;
    };
    match lhs.op() {
        Op::Reduce { ranges, .. } => {
            assert_eq!(ranges.len(), 1);
            assert!(Arc::ptr_eq(&ranges[0], &used));
        }
        other => panic!("Expected REDUCE in inner op, got {:?}", other),
    }
}
/// Two unreferenced ranges yield a `Mul` whose first operand is itself a `Mul`.
#[test]
fn test_reduce_unparented_multiple_unparented() {
    let five = UOp::native_const(5i32);
    let first = UOp::range_axis(UOp::index_const(3), AxisId::Renumbered(0), AxisType::Reduce);
    let second = UOp::range_axis(UOp::index_const(4), AxisId::Renumbered(1), AxisType::Reduce);
    let reduce = five.reduce(vec![first, second].into(), ReduceOp::Add);

    let simplified = reduce_unparented(&reduce).expect("Should simplify");
    assert!(matches!(simplified.op(), Op::Binary(BinaryOp::Mul, _, _)));
    if let Op::Binary(_, lhs, _) = simplified.op() {
        let lhs_is_mul = matches!(lhs.op(), Op::Binary(BinaryOp::Mul, _, _));
        assert!(lhs_is_mul);
    }
}
/// Running the matcher on a plain constant produces no rewrite.
#[test]
fn test_reduce_unparented_non_reduce_returns_none() {
    let scalar = UOp::native_const(1.0f32);
    assert!(reduce_unparented(&scalar).is_none());
}
/// Collapsing a constant SUM removes every RANGE and REDUCE node and keeps the dtype.
#[test]
fn test_reduce_collapse_basic() {
    let five = UOp::native_const(5i32);
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let reduce = Arc::clone(&five).reduce(vec![axis].into(), ReduceOp::Add);

    let collapsed =
        reduce_collapse(&reduce).expect("reduce_collapse should succeed on constant");
    assert!(!has_ranges_in_graph(&collapsed), "Result should have no range dependencies");
    assert!(!has_reduce_op(&collapsed), "Result should not contain REDUCE operations");
    assert_eq!(collapsed.dtype(), five.dtype(), "Should preserve dtype");
}
/// `sum(r + 1)` genuinely depends on the range, so collapse must decline (return `None`).
#[test]
fn test_reduce_collapse_with_range_dependency() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let body = axis.cast(DType::Int32).try_add(&UOp::native_const(1i32)).unwrap();
    let reduce = body.reduce(vec![axis].into(), ReduceOp::Add);

    assert!(
        reduce_collapse(&reduce).is_none(),
        "reduce_collapse should return None when range dependency can't be eliminated"
    );
}
/// A non-REDUCE node is rejected by the `reduce_collapse` helper.
#[test]
fn test_reduce_collapse_non_reduce_returns_none() {
    let scalar = UOp::native_const(1.0f32);
    assert!(reduce_collapse(&scalar).is_none());
}
/// A REDUCE with an empty range list has nothing to collapse.
#[test]
fn test_reduce_collapse_empty_ranges() {
    let reduce = UOp::native_const(5i32).reduce(vec![].into(), ReduceOp::Add);
    assert!(
        reduce_collapse(&reduce).is_none(),
        "reduce_collapse should return None for empty ranges"
    );
}
/// Collapsing a constant SUM over two independent ranges must succeed and leave
/// no range dependencies behind.
#[test]
fn test_reduce_collapse_multiple_ranges_all_independent() {
    let const_val = UOp::native_const(5i32);
    let range1 = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let range2 = UOp::range_axis(UOp::index_const(20), AxisId::Renumbered(1), AxisType::Reduce);
    // `const_val` is not used again, so no clone is needed here.
    let reduce = const_val.reduce(vec![range1, range2].into(), ReduceOp::Add);

    // `expect` replaces the `is_some()` + `if let Some` pair: same failure condition,
    // but the value is unwrapped in one step.
    let res = reduce_collapse(&reduce)
        .expect("reduce_collapse should succeed with multiple independent ranges");
    assert!(crate::rangeify::indexing::no_range(&res), "Result should have no range dependencies");
}
/// `sum(42 + 0)` collapses fully: no RANGE, no REDUCE, and no `Add` node remains.
#[test]
fn test_reduce_collapse_algebraic_simplification() {
    let answer = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let padded = answer.try_add(&zero).unwrap();
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let reduce = padded.reduce(vec![axis].into(), ReduceOp::Add);

    let collapsed = reduce_collapse(&reduce)
        .expect("reduce_collapse should succeed after x+0 simplification");
    assert!(!has_ranges_in_graph(&collapsed), "x+0 simplification should eliminate ranges");
    assert!(!has_reduce_op(&collapsed), "Result should not contain REDUCE");

    let leftover_add = collapsed
        .toposort()
        .iter()
        .any(|node| matches!(node.op(), Op::Binary(BinaryOp::Add, _, _)));
    assert!(!leftover_add, "x+0 should be simplified away");
}
/// `prod(PI * 1.0)` collapses fully: no RANGE, no REDUCE, and no `Mul` node remains.
#[test]
fn test_reduce_collapse_multiplication_by_one() {
    let pi = UOp::native_const(PI);
    let unit = UOp::native_const(1.0f32);
    let scaled = pi.try_mul(&unit).unwrap();
    let axis = UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce);
    let reduce = scaled.reduce(vec![axis].into(), ReduceOp::Mul);

    let collapsed = reduce_collapse(&reduce)
        .expect("reduce_collapse should succeed after x*1 simplification");
    assert!(!has_ranges_in_graph(&collapsed), "x*1 simplification should eliminate ranges");
    assert!(!has_reduce_op(&collapsed), "Result should not contain REDUCE");

    let leftover_mul = collapsed
        .toposort()
        .iter()
        .any(|node| matches!(node.op(), Op::Binary(BinaryOp::Mul, _, _)));
    assert!(!leftover_mul, "x*1 should be simplified away");
}
/// Collapsing an `f64` constant SUM must keep the source's dtype.
#[test]
fn test_reduce_collapse_preserves_dtype() {
    let const_val = UOp::native_const(2.5f64);
    let range = UOp::range_axis(UOp::index_const(100), AxisId::Renumbered(0), AxisType::Reduce);
    // Clone is required: `const_val` is compared against the result below.
    let reduce = const_val.clone().reduce(vec![range].into(), ReduceOp::Add);

    // `expect` replaces the `is_some()` + `if let Some` pair with a single unwrap.
    let res = reduce_collapse(&reduce).expect("reduce_collapse should succeed");
    assert_eq!(res.dtype(), const_val.dtype(), "reduce_collapse should preserve dtype");
}
/// A constant source collapses under every reduce op: Add, Mul, Max and Min.
#[test]
fn test_reduce_collapse_different_reduce_ops() {
    let const_val = UOp::native_const(10i32);
    let range = UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce);

    let cases = [
        (ReduceOp::Add, "reduce_collapse should work with ReduceOp::Add"),
        (ReduceOp::Mul, "reduce_collapse should work with ReduceOp::Mul"),
        (ReduceOp::Max, "reduce_collapse should work with ReduceOp::Max"),
        (ReduceOp::Min, "reduce_collapse should work with ReduceOp::Min"),
    ];
    for (op, message) in cases {
        let reduce = const_val.clone().reduce(vec![range.clone()].into(), op);
        assert!(reduce_collapse(&reduce).is_some(), "{}", message);
    }
}
/// An expression built on top of a RANGE node is reported as range-dependent.
#[test]
fn test_no_range_with_ranges() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let shifted = axis.cast(DType::Int32).try_add(&UOp::native_const(5i32)).unwrap();

    let range_free = crate::rangeify::indexing::no_range(&shifted);
    assert!(!range_free);
}
/// Constants and arithmetic over constants are reported as range-free.
#[test]
fn test_no_range_without_ranges() {
    use crate::rangeify::indexing::no_range;

    let lone = UOp::native_const(42i32);
    assert!(no_range(&lone));

    let ten = UOp::native_const(10i32);
    let twenty = UOp::native_const(20i32);
    let total = ten.try_add(&twenty).unwrap();
    assert!(no_range(&total));
}
/// RANGE nodes whose extent is an index constant report that constant as their size.
#[test]
fn test_range_size_extraction_constant() {
    use crate::rangeify::indexing::range_size_as_i64;

    let loop_axis =
        UOp::range_axis(UOp::index_const(100), AxisId::Renumbered(0), AxisType::Loop);
    assert_eq!(range_size_as_i64(&loop_axis), Some(100));

    let reduce_axis =
        UOp::range_axis(UOp::index_const(42), AxisId::Renumbered(1), AxisType::Reduce);
    assert_eq!(range_size_as_i64(&reduce_axis), Some(42));
}
/// A RANGE sized by a symbolic variable has no concrete size.
#[test]
fn test_range_size_extraction_symbolic() {
    let n = UOp::define_var("N".to_string(), 0, 1000);
    let symbolic_axis = UOp::range_axis(n, AxisId::Renumbered(0), AxisType::Loop);

    let size = crate::rangeify::indexing::range_size_as_i64(&symbolic_axis);
    assert_eq!(size, None);
}
/// Non-RANGE nodes (a constant, a sum of constants) report no range size.
#[test]
fn test_range_size_extraction_non_range() {
    use crate::rangeify::indexing::range_size_as_i64;

    let hundred = UOp::native_const(100i32);
    assert_eq!(range_size_as_i64(&hundred), None);

    let ten = UOp::native_const(10i32);
    let twenty = UOp::native_const(20i32);
    let total = ten.try_add(&twenty).unwrap();
    assert_eq!(range_size_as_i64(&total), None);
}
/// `sum(r * 3)` is rewritten to a `Mul` whose first operand is a REDUCE.
#[test]
fn test_reduce_mul_chain_simple_const() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let scaled = axis.cast(DType::Int32).mul(&UOp::native_const(3i32));
    let reduce = scaled.reduce(vec![axis].into(), ReduceOp::Add);

    let factored = reduce_unparented(&reduce).expect("Should factor const out of reduce");
    assert!(matches!(factored.op(), Op::Binary(BinaryOp::Mul, _, _)));
    if let Op::Binary(BinaryOp::Mul, lhs, _) = factored.op() {
        assert!(matches!(lhs.op(), Op::Reduce { .. }));
    }
}
/// Smoke test: `sum(r * r)` has no range-independent factor.
///
/// This only checks that running the matcher does not panic; the result is not
/// inspected. NOTE(review): consider pinning the expected outcome.
#[test]
fn test_reduce_mul_chain_no_outside_factors() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let as_int = axis.cast(DType::Int32);
    let squared = as_int.mul(&as_int);
    let reduce = squared.reduce(vec![axis].into(), ReduceOp::Add);

    let _outcome = reduce_unparented(&reduce);
}
/// `sum(2 * r * 5)` is rewritten to a `Mul` at the top level.
#[test]
fn test_reduce_mul_chain_multiple_factors() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let two = UOp::native_const(2i32);
    let five = UOp::native_const(5i32);
    let as_int = axis.cast(DType::Int32);
    let body = two.mul(&as_int).mul(&five);
    let reduce = body.reduce(vec![axis].into(), ReduceOp::Add);

    let factored = reduce_unparented(&reduce).expect("Should factor constants out");
    assert!(matches!(factored.op(), Op::Binary(BinaryOp::Mul, _, _)));
}
/// `max(r * 3)` with a positive constant factor is rewritten to a top-level `Mul`.
#[test]
fn test_reduce_mul_chain_max_positive_factor() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let three = UOp::native_const(3i32);
    let scaled = axis.cast(DType::Int32).mul(&three);
    let reduce = scaled.reduce(vec![axis].into(), ReduceOp::Max);

    let factored = reduce_unparented(&reduce)
        .expect("Should factor positive const out of MAX reduce");
    assert!(matches!(factored.op(), Op::Binary(BinaryOp::Mul, _, _)));
}
/// `max(r * -1)`: a negative factor must not be pulled out of a MAX reduce.
///
/// Passes trivially when there is no rewrite. When there is one, it passes unless the
/// result is a `Mul` whose second operand is the constant `-1`.
#[test]
fn test_reduce_mul_chain_max_negative_factor_stays() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let neg_one = UOp::native_const(-1i32);
    let negated = axis.cast(DType::Int32).mul(&neg_one);
    let reduce = negated.reduce(vec![axis].into(), ReduceOp::Max);

    let Some(rewritten) = reduce_unparented(&reduce) else {
        return;
    };
    let Op::Binary(BinaryOp::Mul, _inner, factor) = rewritten.op() else {
        return;
    };
    let Op::Const(c) = factor.op() else {
        return;
    };
    assert!(
        c.0 != morok_ir::ConstValue::Int(-1),
        "Negative factor should not be factored out of MAX reduce"
    );
}
/// Smoke test: `sum(r)` has no multiplication chain at all.
///
/// This only checks that running the matcher does not panic; the result is not
/// inspected. NOTE(review): consider pinning the expected outcome.
#[test]
fn test_reduce_mul_chain_single_factor_no_op() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let as_int = axis.cast(DType::Int32);
    let reduce = as_int.reduce(vec![axis].into(), ReduceOp::Add);

    let _outcome = reduce_unparented(&reduce);
}