use std::{f32::consts::PI, sync::Arc};
use morok_device::DeviceSpec;
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, BufferizeOpts, ConstValue, Op, ReduceOp, UOp};
use crate::rangeify::{rangeify, run_kernel_split_pipeline, run_rangeify};
/// Runs [`rangeify`] on `uop` (passing `None` as its second argument) and
/// returns only the rewritten graph, discarding the accompanying context.
///
/// # Panics
///
/// Panics if `rangeify` returns an error. The error is included in the panic
/// message so a failing test shows the underlying cause, and `#[track_caller]`
/// makes the reported location point at the calling test.
#[track_caller]
fn rangeify_unwrap(uop: Arc<UOp>) -> Arc<UOp> {
    match rangeify(uop, None) {
        Ok((rangeified, _ctx)) => rangeified,
        Err(err) => panic!("rangeify failed: {:?}", err),
    }
}
/// A lone constant must pass through `run_rangeify` successfully and still be
/// a `Const` at the root of the result.
#[test]
fn test_run_rangeify_simple_const() {
    let const_val = UOp::native_const(42.0f32);
    // `expect` keeps the original failure text and also prints the error value.
    let (rangeified, _ctx) = run_rangeify(const_val).expect("rangeify should succeed");
    assert!(matches!(rangeified.op(), Op::Const(_)));
}
/// `rangeify` must not leave a `Detach` node at the root of the result.
///
/// Only the absence of `Detach` at the root is checked; the test does not
/// assert which op the root becomes instead.
#[test]
fn test_run_rangeify_detach_removal() {
    let x = UOp::native_const(1.0f32);
    let detach = x.detach();
    let rangeified = rangeify_unwrap(detach);
    assert!(
        !matches!(rangeified.op(), Op::Detach { .. }),
        "DETACH should not survive rangeify at the root"
    );
}
/// `rangeify` must not leave a `ContiguousBackward` node at the root of the
/// result.
///
/// Only the absence of `ContiguousBackward` at the root is checked; the test
/// does not assert which op the root becomes instead.
#[test]
fn test_run_rangeify_contiguous_backward_removal() {
    let x = UOp::native_const(PI);
    let contiguous = x.contiguous_backward();
    let rangeified = rangeify_unwrap(contiguous);
    assert!(
        !matches!(rangeified.op(), Op::ContiguousBackward { .. }),
        "CONTIGUOUS_BACKWARD should not survive rangeify at the root"
    );
}
/// Adding two `f32` constants and running `rangeify` must yield a `Float32`
/// result.
#[test]
fn test_run_rangeify_binary_op() {
    let a = UOp::native_const(1.0f32);
    let b = UOp::native_const(2.0f32);
    let add = a.try_add(&b).unwrap();
    let rangeified = rangeify_unwrap(add);
    // `assert_eq!` reports both dtypes on failure, matching the other tests here.
    assert_eq!(rangeified.dtype(), DType::Float32);
}
/// `(1 + 2) * 3` over `f32` constants must still be `Float32` after `rangeify`.
///
/// The kind of the root op is deliberately left unconstrained: this test only
/// pins the dtype.
#[test]
fn test_run_rangeify_preserves_structure() {
    let a = UOp::native_const(1.0f32);
    let b = UOp::native_const(2.0f32);
    let c = UOp::native_const(3.0f32);
    let product = a.try_add(&b).unwrap().try_mul(&c).unwrap();
    let rangeified = rangeify_unwrap(product);
    assert_eq!(rangeified.dtype(), DType::Float32);
}
/// Feeds a single constant store through the kernel-split pipeline and checks
/// that the output is either `Void`-typed or a `Kernel` op.
#[test]
fn test_kernel_split_pipeline_simple_store() {
    let _buffer = UOp::buffer_id(Some(0));
    let store = UOp::index_const(0).store(UOp::native_const(1.0f32));
    let (output, _ctx) = run_kernel_split_pipeline(store);
    let is_void = output.dtype() == DType::Void;
    assert!(is_void || matches!(output.op(), Op::Kernel { .. }));
}
/// Wraps a constant store in an `end` over one loop range, runs the
/// kernel-split pipeline, and checks that the output is either `Void`-typed
/// or a `Kernel` op.
#[test]
fn test_kernel_split_pipeline_with_end() {
    let _buffer = UOp::buffer_id(Some(0));
    let store = UOp::index_const(0).store(UOp::native_const(1.0f32));
    let loop_range = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    let ended = store.end(vec![loop_range].into());
    let (output, _ctx) = run_kernel_split_pipeline(ended);
    let is_void = output.dtype() == DType::Void;
    assert!(is_void || matches!(output.op(), Op::Kernel { .. }));
}
/// Stores a value loaded from a CPU buffer, runs the kernel-split pipeline,
/// and checks that the output is either `Void`-typed or a `Kernel` op.
#[test]
fn test_kernel_split_pipeline_load_store() {
    let source = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let _out_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let idx = UOp::index_const(0);
    let loaded = UOp::load().buffer(source).index(idx.clone()).call();
    let (output, _ctx) = run_kernel_split_pipeline(idx.store(loaded));
    let is_void = output.dtype() == DType::Void;
    assert!(is_void || matches!(output.op(), Op::Kernel { .. }));
}
/// Stores the sum of two loads from separate CPU buffers, runs the
/// kernel-split pipeline, and checks that the output is either `Void`-typed
/// or a `Kernel` op.
#[test]
fn test_kernel_split_pipeline_multiple_loads() {
    let lhs_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let rhs_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let _out_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let idx = UOp::index_const(0);
    let lhs = UOp::load().buffer(lhs_buf).index(idx.clone()).call();
    let rhs = UOp::load().buffer(rhs_buf).index(idx.clone()).call();
    let total = lhs.try_add(&rhs).unwrap();
    let (output, _ctx) = run_kernel_split_pipeline(idx.store(total));
    let is_void = output.dtype() == DType::Void;
    assert!(is_void || matches!(output.op(), Op::Kernel { .. }));
}
/// Runs `rangeify` followed by the kernel-split pipeline on a store of
/// `1.0 + 2.0` and checks that the final output is either `Void`-typed or a
/// `Kernel` op.
#[test]
fn test_end_to_end_simple_computation() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let total = lhs.try_add(&rhs).unwrap();
    let _buffer = UOp::buffer_id(Some(0));
    let store = UOp::index_const(0).store(total);
    let (output, _ctx) = run_kernel_split_pipeline(rangeify_unwrap(store));
    let is_void = output.dtype() == DType::Void;
    assert!(is_void || matches!(output.op(), Op::Kernel { .. }));
}
/// Runs `rangeify` followed by the kernel-split pipeline on a constant store
/// ended over one loop range of 100, and checks that the final output is
/// either `Void`-typed or a `Kernel` op.
#[test]
fn test_end_to_end_with_ranges() {
    let _buffer = UOp::buffer_id(Some(0));
    let store = UOp::index_const(0).store(UOp::native_const(1.0f32));
    let loop_range = UOp::range_axis(UOp::index_const(100), AxisId::Renumbered(0), AxisType::Loop);
    let ended = store.end(vec![loop_range].into());
    let (output, _ctx) = run_kernel_split_pipeline(rangeify_unwrap(ended));
    let is_void = output.dtype() == DType::Void;
    assert!(is_void || matches!(output.op(), Op::Kernel { .. }));
}
/// Applying `rangeify` twice to a constant must leave its dtype unchanged.
///
/// Only the dtype is compared; the structure of the two results is not.
#[test]
fn test_pipeline_idempotent() {
    let x = UOp::native_const(1.0f32);
    // Capture the dtype up front so `x` can be moved into the pipeline.
    let expected = x.dtype();
    let once = rangeify_unwrap(x);
    let twice = rangeify_unwrap(once);
    assert_eq!(twice.dtype(), expected);
}
/// For constants of several scalar types, a result that is still a `Const`
/// after `rangeify` must carry the dtype it started with.
///
/// Results whose root is not a `Const` are not checked.
#[test]
fn test_pipeline_preserves_dtype() {
    // A fixed-size array is enough here; no heap-allocated `Vec` is needed.
    let test_cases = [
        (DType::Float32, UOp::native_const(1.0f32)),
        (DType::Float64, UOp::native_const(1.0f64)),
        (DType::Int32, UOp::native_const(42i32)),
        (DType::Int64, UOp::native_const(42i64)),
        (DType::Bool, UOp::native_const(true)),
    ];
    for (dtype, value) in test_cases {
        // `value` is owned by the loop and not reused, so it is moved, not cloned.
        let rangeified = rangeify_unwrap(value);
        if matches!(rangeified.op(), Op::Const(_)) {
            assert_eq!(rangeified.dtype(), dtype);
        }
    }
}
/// Running `rangeify` on a no-op must yield something that is either
/// `Void`-typed or still a `Noop`.
#[test]
fn test_pipeline_handles_noop() {
    let output = rangeify_unwrap(UOp::noop());
    let is_void = output.dtype() == DType::Void;
    assert!(is_void || matches!(output.op(), Op::Noop));
}
/// Builds a left-nested chain of ten `+ 1.0` additions on top of `1.0` and
/// checks that the rangeified result is `Float32`.
#[test]
fn test_pipeline_complex_nested_structure() {
    let chain = (0..10).fold(UOp::native_const(1.0f32), |acc, _| {
        let one = UOp::native_const(1.0f32);
        acc.try_add(&one).unwrap()
    });
    let output = rangeify_unwrap(chain);
    assert_eq!(output.dtype(), DType::Float32);
}
/// Sums the twenty constants `0.0..=19.0` left to right and checks that the
/// rangeified result is `Float32`.
#[test]
fn test_pipeline_wide_tree() {
    // All operands are materialised before any addition is built.
    let operands: Vec<_> = (0..20).map(|i| UOp::native_const(i as f32)).collect();
    let total = operands[1..]
        .iter()
        .fold(operands[0].clone(), |acc, operand| acc.try_add(operand).unwrap());
    let output = rangeify_unwrap(total);
    assert_eq!(output.dtype(), DType::Float32);
}
/// After `rangeify`, the root of a detached constant must no longer be a
/// `Detach` op.
#[test]
fn test_pipeline_applies_early_rewrites_first() {
    let source = UOp::native_const(1.0f32);
    let output = rangeify_unwrap(source.detach());
    assert!(
        !matches!(output.op(), Op::Detach { .. }),
        "DETACH should have been removed by early_rewrites"
    );
}
/// Smoke test: `rangeify` must complete without panicking on a locally
/// bufferized constant over a single loop range.
///
/// No property of the result is asserted; any root op is accepted.
#[test]
fn test_pipeline_applies_buffer_folding() {
    let const_val = UOp::native_const(42.0f32);
    let range_end = UOp::index_const(10);
    let range = UOp::range_axis(range_end, AxisId::Renumbered(0), AxisType::Loop);
    // `const_val` is not used again, so it is moved rather than cloned.
    let bufferize = UOp::bufferize(const_val, vec![range], BufferizeOpts::local());
    let _rangeified = rangeify_unwrap(bufferize);
}
/// Builds `2.0 * 3.0 + 4.0` from `f32` constants and checks that the
/// rangeified result is `Float32`.
#[test]
fn test_pipeline_maintains_computation_semantics() {
    let two = UOp::native_const(2.0f32);
    let three = UOp::native_const(3.0f32);
    let four = UOp::native_const(4.0f32);
    let expr = two.try_mul(&three).unwrap().try_add(&four).unwrap();
    let output = rangeify_unwrap(expr);
    assert_eq!(output.dtype(), DType::Float32);
}
/// Add-reduces the constant `5` over a reduce range of size 10.
///
/// If the pipeline folds the result into an integer constant, that constant
/// must be `50`. Any other root op (for example a multiplication) is accepted
/// without further checks.
#[test]
fn test_pipeline_reduce_unparented_add() {
    let const_val = UOp::native_const(5i32);
    let range = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let reduce = const_val.reduce(vec![range].into(), ReduceOp::Add);
    let rangeified = rangeify_unwrap(reduce);
    if let Op::Const(cv_hash) = rangeified.op() {
        if let ConstValue::Int(n) = cv_hash.0 {
            assert_eq!(n, 50, "reduce_unparented should optimize to 50");
        }
    }
}
/// Max-reduces the constant `42` over a reduce range of size 5.
///
/// If the pipeline folds the result into an integer constant, that constant
/// must still be `42`. Any other root op is accepted without further checks.
#[test]
fn test_pipeline_reduce_unparented_max() {
    let const_val = UOp::native_const(42i32);
    let range = UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce);
    // `const_val` is not used again, so no clone is needed before reducing.
    let reduce = const_val.reduce(vec![range].into(), ReduceOp::Max);
    let rangeified = rangeify_unwrap(reduce);
    if let Op::Const(cv_hash) = rangeified.op() {
        if let ConstValue::Int(n) = cv_hash.0 {
            assert_eq!(n, 42, "reduce_unparented MAX should preserve constant");
        }
    }
}
/// Add-reduces a 100000-element `Float32` CPU buffer along axis 0 and expects
/// the rangeified graph to contain a `Contiguous` node, which this test treats
/// as evidence that the reduction was split.
#[test]
fn test_pipeline_split_reduceop_large_reduction() {
    let total_size = 100000;
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, total_size, DType::Float32);
    let reduce = buffer.try_reduce_axis(ReduceOp::Add, vec![0]).unwrap();
    let rangeified = rangeify_unwrap(reduce);
    let has_contiguous = rangeified
        .toposort()
        .iter()
        .any(|node| matches!(node.op(), Op::Contiguous { .. }));
    assert!(
        has_contiguous,
        "split_reduceop should have split large reduction (100000 > 32768 threshold), creating CONTIGUOUS node"
    );
    assert_eq!(rangeified.dtype(), DType::Float32);
}
/// Add-reduces a 1000-element `Float32` CPU buffer along axis 0 and expects
/// the rangeified graph to contain no `Contiguous` node, which this test
/// treats as evidence that the reduction was not split.
#[test]
fn test_pipeline_split_reduceop_below_threshold() {
    let total_size = 1000;
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, total_size, DType::Float32);
    let reduce = buffer.try_reduce_axis(ReduceOp::Add, vec![0]).unwrap();
    let rangeified = rangeify_unwrap(reduce);
    let has_contiguous = rangeified
        .toposort()
        .iter()
        .any(|node| matches!(node.op(), Op::Contiguous { .. }));
    assert!(!has_contiguous, "split_reduceop should NOT split small reduction (1000 < 32768 threshold)");
    assert_eq!(rangeified.dtype(), DType::Float32);
}
/// Add-reduces the constant `PI` over two reduce ranges (sizes 8 and 4) and
/// checks that `run_rangeify` succeeds with a `Float32` result.
#[test]
fn test_pipeline_reduction_optimizations_dont_break_graph() {
    let data = UOp::native_const(PI);
    let range1 = UOp::range_axis(UOp::index_const(8), AxisId::Renumbered(0), AxisType::Reduce);
    let range2 = UOp::range_axis(UOp::index_const(4), AxisId::Renumbered(1), AxisType::Reduce);
    let reduce = data.reduce(vec![range1, range2].into(), ReduceOp::Add);
    // `expect` keeps the original failure text and also prints the error value.
    let (rangeified, _ctx) = run_rangeify(reduce).expect("Pipeline should handle multi-range reduction");
    assert_eq!(rangeified.dtype(), DType::Float32);
}
/// Add-reduces the constant `42` over a reduce range of size 10 and checks
/// that the rangeified result is not `Void`-typed.
#[test]
fn test_pipeline_reduce_collapse_constant() {
    let constant = UOp::native_const(42i32);
    let reduction = constant.reduce(
        vec![UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce)].into(),
        ReduceOp::Add,
    );
    let output = rangeify_unwrap(reduction);
    assert_ne!(output.dtype(), DType::Void, "Result should have valid dtype after reduce_collapse");
}
/// Add-reduces the constant `PI` over two reduce ranges (sizes 5 and 10) and
/// checks that the rangeified result is `Float32`.
#[test]
fn test_pipeline_reduce_collapse_multiple_ranges() {
    let constant = UOp::native_const(PI);
    let axes = vec![
        UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce),
        UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(1), AxisType::Reduce),
    ];
    let output = rangeify_unwrap(constant.reduce(axes.into(), ReduceOp::Add));
    assert_eq!(output.dtype(), DType::Float32, "Result should preserve Float32 dtype");
}
/// Add-reduces `100 + 0` over a reduce range of size 20 and checks that the
/// rangeified result is `Int32`.
#[test]
fn test_pipeline_reduce_collapse_with_algebraic_simplification() {
    let hundred = UOp::native_const(100i32);
    let zero = UOp::native_const(0i32);
    let padded = hundred.try_add(&zero).unwrap();
    let reduction = padded.reduce(
        vec![UOp::range_axis(UOp::index_const(20), AxisId::Renumbered(0), AxisType::Reduce)].into(),
        ReduceOp::Add,
    );
    let output = rangeify_unwrap(reduction);
    assert_eq!(output.dtype(), DType::Int32);
}
/// Add-reduces `cast(range, Int32) + 1` over that same range (size 10), so the
/// reduced expression depends on the reduce axis, and checks that the
/// rangeified result is `Int32`.
#[test]
fn test_pipeline_reduce_collapse_preserves_dependent_reductions() {
    let axis = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Reduce);
    let one = UOp::native_const(1i32);
    let body = axis.cast(DType::Int32).try_add(&one).unwrap();
    let reduction = body.reduce(vec![axis].into(), ReduceOp::Add);
    let output = rangeify_unwrap(reduction);
    assert_eq!(output.dtype(), DType::Int32);
}
/// Reduces the constant `2.5` over a reduce range of size 8, first with `Max`
/// and then with `Min`, checking after each that the rangeified result is
/// `Float32`.
#[test]
fn test_pipeline_reduce_collapse_different_ops() {
    let constant = UOp::native_const(2.5f32);
    let axis = UOp::range_axis(UOp::index_const(8), AxisId::Renumbered(0), AxisType::Reduce);

    let max_reduction = constant.clone().reduce(vec![axis.clone()].into(), ReduceOp::Max);
    let max_output = rangeify_unwrap(max_reduction);
    assert_eq!(max_output.dtype(), DType::Float32, "MAX reduce should work");

    let min_reduction = constant.reduce(vec![axis].into(), ReduceOp::Min);
    let min_output = rangeify_unwrap(min_reduction);
    assert_eq!(min_output.dtype(), DType::Float32, "MIN reduce should work");
}
/// Add-reduces the constant `7` over two reduce ranges (sizes 5 and 3) and
/// checks that the rangeified result is `Int32`.
#[test]
fn test_pipeline_reduce_collapse_integration_with_unparented() {
    let constant = UOp::native_const(7i32);
    let axes = vec![
        UOp::range_axis(UOp::index_const(5), AxisId::Renumbered(0), AxisType::Reduce),
        UOp::range_axis(UOp::index_const(3), AxisId::Renumbered(1), AxisType::Reduce),
    ];
    let output = rangeify_unwrap(constant.reduce(axes.into(), ReduceOp::Add));
    assert_eq!(output.dtype(), DType::Int32);
}