use std::sync::Arc;
use morok_dtype::{AddrSpace, DType};
use morok_ir::{AxisId, AxisType, ConstValue, Op, UOp};
use smallvec::smallvec;
use crate::rangeify::{KernelContext, patterns::to_param_patterns};
/// Runs the `to_param_patterns` matcher once over `uop`.
///
/// Returns the replacement node when the matcher reports `Rewritten`, and
/// `None` for every other outcome.
fn apply_patterns(uop: &Arc<UOp>, ctx: &mut KernelContext) -> Option<Arc<UOp>> {
    let matcher = to_param_patterns();
    let outcome = matcher.rewrite(uop, ctx);
    if let morok_ir::pattern::RewriteResult::Rewritten(replacement) = outcome {
        Some(replacement)
    } else {
        None
    }
}
/// A CPU `Buffer` is rewritten into a `Param` with no device attached, and
/// the context's `global_counter` ends up at 1.
#[test]
fn test_debuf_global() {
    let mut ctx = KernelContext::new();
    let buffer = UOp::new(
        Op::Buffer {
            unique: UOp::buffer_id(Some(0)),
            device: UOp::device(morok_device::DeviceSpec::Cpu),
            size: 100,
        },
        DType::Float32,
    );

    let Some(param) = apply_patterns(&buffer, &mut ctx) else {
        panic!("Expected Some result");
    };

    assert!(matches!(param.op(), Op::Param { device: None, .. }));
    assert_eq!(ctx.global_counter, 1);
}
/// Binding `x` to the index constant 5 is rewritten back to a bare
/// `DefineVar`, and the bound value is recorded under `"x"` in `ctx.vars`.
#[test]
fn test_unbind_kernel() {
    let mut ctx = KernelContext::new();
    let var = UOp::new(
        Op::DefineVar { name: "x".to_string(), min_val: 0, max_val: 10 },
        DType::Index,
    );
    let bind = var.bind(UOp::index_const(5));

    let unbound = apply_patterns(&bind, &mut ctx).expect("Expected Some result");
    assert!(matches!(unbound.op(), Op::DefineVar { .. }));

    assert!(ctx.vars.contains_key("x"));
    // Second tuple element holds the value the variable was bound to.
    let recorded_value = &ctx.vars.get("x").unwrap().1;
    assert_eq!(*recorded_value, Some(5));
}
/// An unrenumbered `Reduce` range gets renumbered, and it is the first
/// range seen by a fresh context, so it receives axis id 0.
#[test]
fn test_renumber_range() {
    let mut ctx = KernelContext::new();
    let reduce_range =
        UOp::range_axis(UOp::index_const(10), AxisId::Unrenumbered(5), AxisType::Reduce);

    let renumbered = apply_patterns(&reduce_range, &mut ctx).expect("Expected Some result");
    let Op::Range { axis_id, .. } = renumbered.op() else {
        panic!("Expected RANGE operation");
    };
    assert_eq!(*axis_id, AxisId::Renumbered(0));
}
/// A `Loop` range is renumbered in place and stays a `Range` of the same
/// axis type, rather than being turned into some other op.
#[test]
fn test_renumber_range_loop_no_bind() {
    let mut ctx = KernelContext::new();
    let loop_range =
        UOp::range_axis(UOp::index_const(10), AxisId::Unrenumbered(5), AxisType::Loop);

    let rewritten = apply_patterns(&loop_range, &mut ctx).expect("Expected Some result");
    let Op::Range { axis_id, axis_type, .. } = rewritten.op() else {
        panic!("Expected RANGE operation for LOOP axis, got {:?}", rewritten.op());
    };
    assert_eq!(*axis_id, AxisId::Renumbered(0));
    assert_eq!(*axis_type, AxisType::Loop);
}
/// A range that already carries a `Renumbered` id is not touched by the
/// matcher.
#[test]
fn test_renumber_range_already_numbered() {
    let mut ctx = KernelContext::new();
    let numbered =
        UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(5), AxisType::Loop);

    let outcome = apply_patterns(&numbered, &mut ctx);
    assert!(outcome.is_none(), "Already-numbered range should not be renumbered");
}
/// A range whose end is the index constant 0 is replaced by a `Const`.
#[test]
fn test_remove_zero_range() {
    let mut ctx = KernelContext::new();
    let empty_range =
        UOp::range_axis(UOp::index_const(0), AxisId::Renumbered(0), AxisType::Loop);

    let Some(replacement) = apply_patterns(&empty_range, &mut ctx) else {
        panic!("Expected Some result");
    };
    assert!(matches!(replacement.op(), Op::Const(_)));
}
/// A plain native constant is left alone by the matcher.
///
/// NOTE(review): despite the name, no sources are attached to the constant
/// here, so only the negative case is exercised.
#[test]
fn test_cleanup_const_with_sources() {
    let mut ctx = KernelContext::new();
    let literal = UOp::native_const(42i32);
    assert!(apply_patterns(&literal, &mut ctx).is_none());
}
/// `AFTER` on a buffer id is unwrapped to that buffer (a `Unique`), and the
/// context records the `AFTER` node as the buffer's mapping.
#[test]
fn test_handle_after() {
    let mut ctx = KernelContext::new();
    let buffer = UOp::buffer_id(Some(0));
    let after = buffer.after(smallvec![UOp::noop()]);

    let unwrapped = apply_patterns(&after, &mut ctx).expect("Expected Some result");
    assert!(matches!(unwrapped.op(), Op::Unique(_)));

    assert!(ctx.has_buffer(&buffer));
    let recorded = ctx.get_buffer(&buffer).unwrap();
    assert!(Arc::ptr_eq(recorded, &after));
}
/// Each distinct buffer rewritten by the matcher advances `global_counter`
/// by one, and both buffers end up tracked in the context.
#[test]
fn test_debuf_counter_increment() {
    let mut ctx = KernelContext::new();
    // Builds a Float32 CPU buffer with the given buffer id and size.
    let cpu_buffer = |id, size| {
        UOp::new(
            Op::Buffer {
                unique: UOp::buffer_id(Some(id)),
                device: UOp::device(morok_device::DeviceSpec::Cpu),
                size,
            },
            DType::Float32,
        )
    };
    let buffer1 = cpu_buffer(1, 100);
    let buffer2 = cpu_buffer(2, 200);

    for (buffer, expected_count) in [(&buffer1, 1), (&buffer2, 2)] {
        assert!(apply_patterns(buffer, &mut ctx).is_some());
        assert_eq!(ctx.global_counter, expected_count);
    }

    assert!(ctx.has_buffer(&buffer1));
    assert!(ctx.has_buffer(&buffer2));
}
/// Rewriting a CPU `Buffer` yields a device-less `Param` in slot 0, and the
/// context maps the original buffer to that very `Param` node (pointer
/// identity, not just structural equality).
#[test]
fn test_debuf_buffer_mapping() {
    let mut ctx = KernelContext::new();
    let unique = UOp::buffer_id(Some(0));
    let device = UOp::device(morok_device::DeviceSpec::Cpu);
    let buffer = UOp::new(Op::Buffer { unique, device, size: 100 }, DType::Float32);

    let param = apply_patterns(&buffer, &mut ctx).expect("Expected Some result");
    assert!(matches!(param.op(), Op::Param { slot: 0, device: None, .. }));

    assert!(ctx.has_buffer(&buffer));
    let mapped = ctx.get_buffer(&buffer).unwrap();
    // Previously written as a mis-encoded `¶m` (HTML `&para;` corruption of
    // `&param`), which does not compile.
    assert!(Arc::ptr_eq(mapped, &param));
}
/// `AFTER` on an `MStack` unwraps to the first stacked buffer, and that
/// buffer is mapped to the `AFTER` node in the context.
#[test]
fn test_handle_after_mstack_unwrap() {
    let mut ctx = KernelContext::new();
    let first = UOp::buffer_id(Some(1));
    let second = UOp::buffer_id(Some(2));
    let stack_dtype = first.dtype();
    let mstack = UOp::new(Op::MStack { buffers: smallvec![first.clone(), second] }, stack_dtype);
    let after = mstack.after(smallvec![UOp::noop()]);

    let unwrapped = apply_patterns(&after, &mut ctx).expect("Expected Some result");
    assert!(matches!(unwrapped.op(), Op::Unique(_)));
    assert!(Arc::ptr_eq(&unwrapped, &first));

    let recorded = ctx.get_buffer(&first).unwrap();
    assert!(Arc::ptr_eq(recorded, &after));
}
/// `AFTER` on an `MSelect` unwraps to the selected buffer, and that buffer
/// is mapped to the `AFTER` node in the context.
#[test]
fn test_handle_after_mselect_unwrap() {
    let mut ctx = KernelContext::new();
    let inner = UOp::buffer_id(Some(1));
    let mselect =
        UOp::new(Op::MSelect { buffer: inner.clone(), device_index: 0 }, inner.dtype());
    let after = mselect.after(smallvec![UOp::noop()]);

    let unwrapped = apply_patterns(&after, &mut ctx).expect("Expected Some result");
    assert!(Arc::ptr_eq(&unwrapped, &inner));

    let recorded = ctx.get_buffer(&inner).unwrap();
    assert!(Arc::ptr_eq(recorded, &after));
}
/// Renumbering preserves the axis type for `Loop`, `Reduce`, and `Outer`
/// ranges processed back-to-back on one context.
#[test]
fn test_renumber_range_different_axis_types() {
    let mut ctx = KernelContext::new();
    let end = UOp::index_const(10);
    let axis_types = [AxisType::Loop, AxisType::Reduce, AxisType::Outer];

    for (i, &axis_type) in axis_types.iter().enumerate() {
        let range = UOp::range_axis(end.clone(), AxisId::Unrenumbered(i), axis_type);
        let Some(rewritten) = apply_patterns(&range, &mut ctx) else {
            panic!("Expected Some result for {:?}", axis_type);
        };
        match rewritten.op() {
            Op::Range { axis_type: new_type, .. } => assert_eq!(*new_type, axis_type),
            other => panic!("Expected Range for {:?}, got {:?}", axis_type, other),
        }
    }
}
/// Ranges that already carry `Renumbered` ids are left untouched, even when
/// two of them are applied back-to-back on the same context.
#[test]
fn test_renumber_range_no_change_if_same() {
    let mut ctx = KernelContext::new();
    let end = UOp::index_const(10);

    let range1 = UOp::range_axis(end.clone(), AxisId::Renumbered(5), AxisType::Loop);
    // The first result used to be discarded. Assert it too, so a spurious
    // rewrite of an already-numbered range cannot slip through unnoticed.
    assert!(apply_patterns(&range1, &mut ctx).is_none());

    let range2 = UOp::range_axis(end.clone(), AxisId::Renumbered(1), AxisType::Loop);
    let result = apply_patterns(&range2, &mut ctx);
    assert!(result.is_none());
}
/// A free-standing `DefineVar` (not wrapped in a bind) is not rewritten.
#[test]
#[ignore = "Incomplete: only tests negative case, missing spurious sources test case"]
fn test_cleanup_const_define_var() {
    let mut ctx = KernelContext::new();
    let x = UOp::new(
        Op::DefineVar { name: String::from("x"), min_val: 0, max_val: 10 },
        DType::Index,
    );
    assert!(apply_patterns(&x, &mut ctx).is_none());
}
/// A zero-length range collapses to a `Const`.
///
/// NOTE(review): despite the `_uint` suffix, the end value is built with
/// `index_const`, so this duplicates `test_remove_zero_range`. An unsigned
/// end value is not exercised here.
#[test]
fn test_remove_zero_range_uint() {
    let mut ctx = KernelContext::new();
    let zero_end = UOp::index_const(0);
    let empty_range = UOp::range_axis(zero_end, AxisId::Renumbered(0), AxisType::Loop);

    let Some(collapsed) = apply_patterns(&empty_range, &mut ctx) else {
        panic!("Expected Some result");
    };
    assert!(matches!(collapsed.op(), Op::Const(_)));
}
/// A range with a non-zero end (10) and an already-renumbered id is not
/// rewritten.
#[test]
fn test_remove_zero_range_non_zero() {
    let mut ctx = KernelContext::new();
    let ten_wide =
        UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Loop);
    assert!(apply_patterns(&ten_wide, &mut ctx).is_none());
}
/// Intended behavior for `AFTER` over an `MStack` with no dependencies: it
/// unwraps to the first buffer, and the `MStack` itself is recorded in
/// `ctx.buffer_map`. Ignored until that handling lands.
#[test]
#[ignore = "MSTACK/AFTER handling not fully implemented yet"]
fn test_handle_after_mstack_advanced() {
    let mut ctx = KernelContext::new();
    let buf1 = UOp::buffer_id(Some(1));
    let buf2 = UOp::buffer_id(Some(2));
    let mstack =
        UOp::new(Op::MStack { buffers: smallvec![buf1.clone(), buf2] }, DType::Float32);
    let after = mstack.after(smallvec::SmallVec::new());

    let Some(unwrapped) = apply_patterns(&after, &mut ctx) else {
        panic!("Expected Rewritten result");
    };
    assert!(Arc::ptr_eq(&unwrapped, &buf1));
    assert!(ctx.buffer_map.contains_key(&morok_ir::UOpKey(mstack)));
}
/// A plain native constant is not rewritten.
///
/// NOTE(review): no spurious sources are actually attached here. The body
/// is identical to `test_cleanup_const_with_sources`, so the case this name
/// describes is still untested.
#[test]
fn test_cleanup_const_with_spurious_sources() {
    let mut ctx = KernelContext::new();
    let forty_two = UOp::native_const(42i32);
    let outcome = apply_patterns(&forty_two, &mut ctx);
    assert!(outcome.is_none());
}
/// Three unrenumbered ranges applied in order on one context receive
/// consecutive ids 0, 1, 2 and keep their axis types. Afterwards
/// `range_counter` is 3.
#[test]
fn test_renumber_range_sequential() {
    let mut ctx = KernelContext::new();
    let range0 = UOp::range_axis(UOp::index_const(10), AxisId::Unrenumbered(0), AxisType::Loop);
    let range1 = UOp::range_axis(UOp::index_const(20), AxisId::Unrenumbered(1), AxisType::Loop);
    let range2 =
        UOp::range_axis(UOp::index_const(30), AxisId::Unrenumbered(2), AxisType::Reduce);

    // (input range, expected new id, expected axis type, panic text if not a Range)
    let cases = [
        (&range0, 0, AxisType::Loop, "Expected RANGE operation for LOOP"),
        (&range1, 1, AxisType::Loop, "Expected RANGE operation for LOOP"),
        (&range2, 2, AxisType::Reduce, "Expected RANGE operation"),
    ];

    for (range, expected_id, expected_type, not_a_range_msg) in cases {
        let Some(new_range) = apply_patterns(range, &mut ctx) else {
            panic!("Expected renumbered range");
        };
        let Op::Range { axis_id, axis_type, .. } = new_range.op() else {
            panic!("{}", not_a_range_msg);
        };
        assert_eq!(*axis_id, AxisId::Renumbered(expected_id));
        assert_eq!(*axis_type, expected_type);
    }

    assert_eq!(ctx.range_counter, 3);
}
/// A zero-length range collapses to a new `Const` node (not the range
/// itself) that holds integer 0 with dtype `Index`.
#[test]
fn test_remove_zero_range_verification() {
    let mut ctx = KernelContext::new();
    let range = UOp::range_axis(UOp::index_const(0), AxisId::Renumbered(0), AxisType::Loop);

    let Some(const_op) = apply_patterns(&range, &mut ctx) else {
        panic!("Expected Rewritten result for zero range");
    };
    let Op::Const(val) = const_op.op() else {
        panic!("Expected CONST operation");
    };

    assert_eq!(val.0, ConstValue::Int(0));
    assert!(!Arc::ptr_eq(&const_op, &range));
    assert_eq!(const_op.dtype(), DType::Index);
}
/// Renumbering a `Reduce` range reuses the original `end` node (pointer
/// identity). Feeding the result back through the matcher is a no-op.
#[test]
fn test_pattern_composition_sequence() {
    let mut ctx = KernelContext::new();
    let range_unnum =
        UOp::range_axis(UOp::index_const(15), AxisId::Unrenumbered(7), AxisType::Reduce);

    let Some(renumbered) = apply_patterns(&range_unnum, &mut ctx) else {
        panic!("Expected Rewritten result");
    };
    let Op::Range { axis_id, end, axis_type, .. } = renumbered.op() else {
        panic!("Expected RANGE operation");
    };

    assert_eq!(*axis_id, AxisId::Renumbered(0));
    assert_eq!(*axis_type, AxisType::Reduce);
    if let Op::Range { end: original_end, .. } = range_unnum.op() {
        assert!(Arc::ptr_eq(end, original_end));
    }

    // A second pass over the already-renumbered range must not rewrite it.
    assert!(apply_patterns(&renumbered, &mut ctx).is_none());
}
/// Renumbering a `Loop` range keeps it a `Range` with id 0 and the same
/// axis type, and reuses the original `end` node (pointer identity).
#[test]
fn test_pattern_composition_sequence_no_bind() {
    let mut ctx = KernelContext::new();
    let range_unnum =
        UOp::range_axis(UOp::index_const(15), AxisId::Unrenumbered(7), AxisType::Loop);

    let Some(new_range) = apply_patterns(&range_unnum, &mut ctx) else {
        panic!("Expected Rewritten result");
    };
    let Op::Range { axis_id, axis_type, end, .. } = new_range.op() else {
        panic!("Expected RANGE operation for LOOP axis, got {:?}", new_range.op());
    };

    assert_eq!(*axis_id, AxisId::Renumbered(0));
    assert_eq!(*axis_type, AxisType::Loop);
    if let Op::Range { end: original_end, .. } = range_unnum.op() {
        assert!(Arc::ptr_eq(end, original_end));
    }
}
/// `AFTER` on a `DefineLocal` buffer unwraps to the local buffer, but the
/// local buffer is not added to the context's buffer tracking.
#[test]
fn test_handle_after_local_buffer_not_tracked() {
    let mut ctx = KernelContext::new();
    let local_buf = UOp::define_local(1, DType::Float32.ptr(Some(1024), AddrSpace::Local));
    let after = local_buf.after(smallvec![UOp::noop()]);

    let Some(unwrapped) = apply_patterns(&after, &mut ctx) else {
        panic!("Expected Rewritten result");
    };
    assert!(matches!(unwrapped.op(), Op::DefineLocal(_)));
    assert!(!ctx.has_buffer(&local_buf));
}
/// `AFTER` on a global `Param` unwraps to the `Param`, and the context maps
/// that `Param` to the `AFTER` node.
#[test]
fn test_handle_after_global_buffer_tracked() {
    let mut ctx = KernelContext::new();
    let global_buf =
        UOp::param(1, 1024, DType::Float32.ptr(Some(1024), AddrSpace::Global), None);
    let after = global_buf.after(smallvec![UOp::noop()]);

    let Some(unwrapped) = apply_patterns(&after, &mut ctx) else {
        panic!("Expected Rewritten result");
    };
    assert!(matches!(unwrapped.op(), Op::Param { device: None, .. }));
    assert!(ctx.has_buffer(&global_buf));

    let recorded = ctx.get_buffer(&global_buf).unwrap();
    assert!(Arc::ptr_eq(recorded, &after));
}
/// `AFTER` on an `MStack` of two local buffers unwraps to the first local
/// buffer, and that buffer is not tracked by the context.
#[test]
fn test_handle_after_mstack_with_local_buffer() {
    let mut ctx = KernelContext::new();
    let shared_dtype = DType::Float32.ptr(Some(512), AddrSpace::Local);
    let head = UOp::define_local(1, shared_dtype.clone());
    let tail = UOp::define_local(2, shared_dtype.clone());
    let mstack = UOp::new(Op::MStack { buffers: smallvec![head.clone(), tail] }, shared_dtype);
    let after = mstack.after(smallvec![UOp::noop()]);

    let Some(unwrapped) = apply_patterns(&after, &mut ctx) else {
        panic!("Expected Rewritten result");
    };
    assert!(Arc::ptr_eq(&unwrapped, &head), "Should unwrap to first buffer in MSTACK");
    assert!(matches!(unwrapped.op(), Op::DefineLocal(1)));
    assert!(!ctx.has_buffer(&head));
}
/// `AFTER` on an `MSelect` over a local buffer unwraps to that local buffer,
/// and the buffer is not tracked by the context.
#[test]
fn test_handle_after_mselect_with_local_buffer() {
    let mut ctx = KernelContext::new();
    let int_local = DType::Int32.ptr(Some(256), AddrSpace::Local);
    let local_buf = UOp::define_local(3, int_local.clone());
    let mselect =
        UOp::new(Op::MSelect { buffer: local_buf.clone(), device_index: 0 }, int_local);
    let after = mselect.after(smallvec![UOp::noop()]);

    let Some(unwrapped) = apply_patterns(&after, &mut ctx) else {
        panic!("Expected Rewritten result");
    };
    assert!(Arc::ptr_eq(&unwrapped, &local_buf), "Should unwrap to buffer from MSELECT");
    assert!(matches!(unwrapped.op(), Op::DefineLocal(3)));
    assert!(!ctx.has_buffer(&local_buf));
}
/// With one local and one global `AFTER` processed on the same context,
/// each unwraps to its own buffer. Only the global buffer ends up tracked.
#[test]
fn test_handle_after_mixed_address_spaces() {
    let mut ctx = KernelContext::new();
    let local_buf = UOp::define_local(10, DType::Float32.ptr(Some(128), AddrSpace::Local));
    let global_buf =
        UOp::param(11, 128, DType::Float32.ptr(Some(128), AddrSpace::Global), None);

    let after_local = local_buf.after(smallvec![UOp::noop()]);
    let after_global = global_buf.after(smallvec![UOp::noop()]);

    let result_local = apply_patterns(&after_local, &mut ctx);
    let result_global = apply_patterns(&after_global, &mut ctx);

    let Some(local_op) = result_local else {
        panic!("Expected Rewritten for local");
    };
    assert!(Arc::ptr_eq(&local_op, &local_buf), "Local AFTER should return local buffer");

    let Some(global_op) = result_global else {
        panic!("Expected Rewritten for global");
    };
    assert!(Arc::ptr_eq(&global_op, &global_buf), "Global AFTER should return global buffer");

    assert!(!ctx.has_buffer(&local_buf), "Local buffer should NOT be tracked");
    assert!(ctx.has_buffer(&global_buf), "Global buffer SHOULD be tracked");
}