use morok_dtype::DType;
use morok_ir::UOp;
use morok_ir::types::{ConstValue, ReduceOp};
use smallvec::smallvec;
use std::sync::Arc;
use crate::rewrite::graph_rewrite;
use super::helpers::{assert_const_value, assert_end_range_count, assert_end_unwrapped, get_matcher};
/// A RANGE whose bound is the constant `0` has no iterations; after rewriting
/// with the shared test matcher it is expected to become the integer constant `0`.
#[test]
fn test_range_zero_to_const() {
    let empty_range = UOp::range(UOp::native_const(0i32), 0);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, empty_range, &mut ());
    assert_const_value(&rewritten, ConstValue::Int(0));
}
/// A RANGE with a negative constant bound (`-5`) is also empty, so the
/// rewrite is expected to fold it to the integer constant `0`.
#[test]
fn test_range_negative_to_const() {
    let empty_range = UOp::range(UOp::native_const(-5i32), 0);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, empty_range, &mut ());
    assert_const_value(&rewritten, ConstValue::Int(0));
}
/// An END whose only range is dead (`range_const(0, 0)`) should be stripped
/// by the rewrite, leaving exactly the original wrapped node (pointer identity).
#[test]
fn test_end_all_dead_ranges_unwrapped() {
    let computation = UOp::noop();
    let end_op = Arc::clone(&computation).end(smallvec![UOp::range_const(0, 0)]);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, end_op, &mut ());
    let inner = assert_end_unwrapped(&rewritten);
    assert!(
        Arc::ptr_eq(&inner, &computation),
        "Expected END to unwrap to original store"
    );
}
/// When only some of an END's ranges are dead, the rewrite should drop the
/// dead one and keep the live ranges, in their original order, around the
/// same computation node.
#[test]
fn test_end_partial_dead_ranges_removed() {
    let computation = UOp::noop();
    let first_live = UOp::range_const(10, 0);
    let dead_range = UOp::range_const(0, 0);
    let second_live = UOp::range_const(5, 0);
    let end_op = Arc::clone(&computation).end(smallvec![
        Arc::clone(&first_live),
        dead_range,
        Arc::clone(&second_live),
    ]);

    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, end_op, &mut ());

    // Exactly the two live ranges must survive.
    let (body, kept) = assert_end_range_count(&rewritten, 2);
    assert!(Arc::ptr_eq(&body, &computation), "Expected same computation");
    assert!(Arc::ptr_eq(&kept[0], &first_live), "Expected first live range preserved");
    assert!(Arc::ptr_eq(&kept[1], &second_live), "Expected second live range preserved");
}
/// An ADD reduction over a dead range has nothing to accumulate, so it is
/// expected to collapse to the additive identity `0`.
#[test]
fn test_reduce_add_empty_to_zero() {
    let x = UOp::var("x", DType::Int32, 0, 100);
    let sum = x.reduce(smallvec![UOp::range_const(0, 0)], ReduceOp::Add);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, sum, &mut ());
    assert_const_value(&rewritten, ConstValue::Int(0));
}
/// A MUL reduction over a dead range (negative bound `-5`) is expected to
/// collapse to the multiplicative identity `1`.
#[test]
fn test_reduce_mul_empty_to_one() {
    let x = UOp::var("x", DType::Int32, 0, 100);
    let product = x.reduce(smallvec![UOp::range_const(-5, 0)], ReduceOp::Mul);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, product, &mut ());
    assert_const_value(&rewritten, ConstValue::Int(1));
}
/// A MAX reduction over a dead range on an `Int32` source is expected to
/// collapse to the smallest `i32` value, widened into the `Int` constant.
#[test]
fn test_reduce_max_empty_to_min() {
    let x = UOp::var("x", DType::Int32, 0, 100);
    let maximum = x.reduce(smallvec![UOp::range_const(0, 0)], ReduceOp::Max);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, maximum, &mut ());
    // Lossless widening; identical value to `i32::MIN as i64`.
    assert_const_value(&rewritten, ConstValue::Int(i64::from(i32::MIN)));
}
/// A RANGE whose bound is the symbolic expression `size - 10` should still be
/// detected as dead. `size` is declared with bounds 0 and 5, so the bound is
/// never positive. The rewrite is expected to fold it to the constant `0`.
#[test]
fn test_range_symbolic_dead() {
    let size = UOp::var("size", DType::Int32, 0, 5);
    let bound = size
        .try_sub(&UOp::native_const(10i32))
        .expect("SUB should succeed");
    let symbolic_range = UOp::range(bound, 0);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, symbolic_range, &mut ());
    assert_const_value(&rewritten, ConstValue::Int(0));
}
/// Boundary case: the RANGE bound is `max(-10, 0)`, whose largest possible
/// value is exactly `0`. A bound that can reach at most `0` still yields no
/// iterations, so the rewrite is expected to fold the RANGE to the constant `0`.
#[test]
fn test_range_boundary_vmax_zero() {
    let neg_ten = UOp::native_const(-10i32);
    let zero = UOp::native_const(0i32);
    // Match the sibling symbolic test: state the invariant instead of a bare unwrap.
    let max_val = neg_ten.try_max(&zero).expect("MAX should succeed");
    let range = UOp::range(max_val, 0);
    let matcher = get_matcher();
    let result = graph_rewrite(&matcher, range, &mut ());
    assert_const_value(&result, ConstValue::Int(0));
}
/// Constructing an END with an empty range list is a no-op. `end` is expected
/// to hand back the very same node rather than wrapping it; no rewrite is involved.
#[test]
fn test_end_empty_ranges_returns_self() {
    let computation = UOp::noop();
    let wrapped = Arc::clone(&computation).end(smallvec![]);
    assert!(
        Arc::ptr_eq(&wrapped, &computation),
        "end(empty) should return self"
    );
}
/// An END whose ranges are all dead (a zero bound and a negative bound) should
/// be stripped entirely, leaving the original wrapped node (pointer identity).
#[test]
fn test_end_multiple_dead_ranges_unwrapped() {
    let computation = UOp::noop();
    let zero_range = UOp::range_const(0, 0);
    let negative_range = UOp::range_const(-5, 0);
    let end_op = Arc::clone(&computation).end(smallvec![zero_range, negative_range]);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, end_op, &mut ());
    let inner = assert_end_unwrapped(&rewritten);
    assert!(
        Arc::ptr_eq(&inner, &computation),
        "Expected END to unwrap to original store"
    );
}
/// An ADD reduction over several dead ranges (zero and negative bounds) is
/// expected to collapse to the additive identity `0`.
#[test]
fn test_reduce_multiple_dead_ranges() {
    let x = UOp::var("x", DType::Int32, 0, 100);
    let zero_range = UOp::range_const(0, 0);
    let negative_range = UOp::range_const(-5, 0);
    let sum = x.reduce(smallvec![zero_range, negative_range], ReduceOp::Add);
    let rules = get_matcher();
    let rewritten = graph_rewrite(&rules, sum, &mut ());
    assert_const_value(&rewritten, ConstValue::Int(0));
}