use morok_ir::{Op, UOp};
use crate::rangeify::{KernelContext, run_kernel_split_pipeline};
use crate::test::unit::rangeify::helpers::{count_kernels, count_stores};
/// A single global BUFFERIZE over one range must be lowered into exactly one kernel.
#[test]
fn test_single_store_one_kernel() {
    let value = UOp::native_const(1.0f32);
    let loop_range = UOp::range_const(10, 0);
    let root = UOp::bufferize_global(value, vec![loop_range]);

    let (kernel_graph, _ctx) = run_kernel_split_pipeline(root);
    assert_eq!(count_kernels(&kernel_graph), 1);
}
/// Two independent global BUFFERIZEs joined under one SINK must be split into two kernels.
#[test]
fn test_double_store_two_kernels() {
    // Construction order (values, then ranges, then bufferizes) matches the original test.
    let (lhs_value, rhs_value) = (UOp::native_const(1.0f32), UOp::native_const(2.0f32));
    let (lhs_range, rhs_range) = (UOp::range_const(10, 0), UOp::range_const(20, 1));
    let lhs = UOp::bufferize_global(lhs_value, vec![lhs_range]);
    let rhs = UOp::bufferize_global(rhs_value, vec![rhs_range]);

    let (kernel_graph, _ctx) = run_kernel_split_pipeline(UOp::sink(vec![lhs, rhs]));
    assert_eq!(count_kernels(&kernel_graph), 2);
}
/// Lowering the same global BUFFERIZE twice must reuse the buffer registered by
/// the first lowering rather than registering a fresh one.
#[test]
fn test_shared_buffer_one_kernel() {
    use crate::rangeify::transforms::bufferize_to_store;

    let mut ctx = KernelContext::new();
    let compute = UOp::native_const(42i32);
    let range = UOp::range_const(5, 0);
    let bufferize = UOp::bufferize_global(compute, vec![range]);

    let _result1 = bufferize_to_store(&bufferize, &mut ctx, true);
    assert!(ctx.has_buffer(&bufferize));
    // Snapshot the buffer registered by the first lowering. Cloning the Arc releases
    // the shared borrow of `ctx` so it can be passed mutably again below. Two lookups
    // taken after both calls would always be pointer-equal and prove nothing.
    let first_buf = ctx
        .get_buffer(&bufferize)
        .cloned()
        .expect("first lowering should register a buffer");

    let _result2 = bufferize_to_store(&bufferize, &mut ctx, true);
    assert!(ctx.has_buffer(&bufferize));
    let second_buf = ctx.get_buffer(&bufferize).unwrap();

    // The second lowering must hand back the very same buffer node.
    assert!(std::sync::Arc::ptr_eq(&first_buf, second_buf));
}
/// Two distinct BUFFERIZEs must each get their own buffer, even when they share a range.
#[test]
fn test_independent_buffers_separate() {
    use crate::rangeify::transforms::bufferize_to_store;

    let mut kernel_ctx = KernelContext::new();
    let first_value = UOp::native_const(1.0f32);
    let second_value = UOp::native_const(2.0f32);
    let shared_range = UOp::range_const(10, 0);
    let first = UOp::bufferize_global(first_value, vec![shared_range.clone()]);
    let second = UOp::bufferize_global(second_value, vec![shared_range]);

    for node in [&first, &second] {
        bufferize_to_store(node, &mut kernel_ctx, true);
    }
    for node in [&first, &second] {
        assert!(kernel_ctx.has_buffer(node));
    }

    let first_buf = kernel_ctx.get_buffer(&first).unwrap();
    let second_buf = kernel_ctx.get_buffer(&second).unwrap();
    assert!(!std::sync::Arc::ptr_eq(first_buf, second_buf));
}
/// Applying `end` to an existing END must nest them. The outer END closes only
/// `range2` and wraps the inner END, which in turn closes only `range1`.
#[test]
fn test_nested_end_operations() {
    let store = UOp::noop();
    let range1 = UOp::range_const(4, 0);
    let range2 = UOp::range_const(8, 1);
    let end1 = store.end(smallvec::smallvec![range1.clone()]);
    let end2 = end1.clone().end(smallvec::smallvec![range2.clone()]);

    if let Op::End { computation, ranges } = end2.op() {
        assert_eq!(ranges.len(), 1);
        assert!(std::sync::Arc::ptr_eq(&ranges[0], &range2));
        assert!(std::sync::Arc::ptr_eq(computation, &end1));
        if let Op::End { ranges: inner_ranges, .. } = computation.op() {
            assert_eq!(inner_ranges.len(), 1);
            assert!(std::sync::Arc::ptr_eq(&inner_ranges[0], &range1));
        } else {
            // Without this branch a non-END inner node would silently skip the
            // inner-range checks and the test would pass vacuously.
            panic!("Expected inner END operation");
        }
    } else {
        panic!("Expected END operation");
    }
}
/// The full pipeline must yield one kernel containing exactly one STORE for a
/// single bufferized boolean constant.
#[test]
fn test_pipeline_kernel_count() {
    let predicate = UOp::native_const(false);
    let root = UOp::bufferize_global(predicate, vec![UOp::range_const(100, 0)]);

    let (kernel_graph, _ctx) = run_kernel_split_pipeline(root);
    assert_eq!(count_kernels(&kernel_graph), 1);
    assert_eq!(count_stores(&kernel_graph), 1, "STORE should be inside KERNEL");
}