use std::f32::consts::PI;
use morok_ir::{Op, UOp};
use crate::rangeify::kernel::run_kernel_split_pipeline;
use crate::test::unit::rangeify::helpers::{count_bufferizes, count_codegen_params, count_kernels, extract_kernel};
/// Feeds a SINK containing one `bufferize_global` and one `bufferize_local` node
/// through the kernel-split pipeline.
///
/// Passes when the result holds at least one codegen PARAM and at least one
/// BUFFERIZE that was left in place.
#[test]
fn test_pipeline_two_bufferizes() {
    // Constants first, then ranges: node creation order matches the original test.
    let (global_value, local_value) = (UOp::native_const(1.0f32), UOp::native_const(42i32));
    let (global_range, local_range) = (UOp::range_const(10, 0), UOp::range_const(5, 1));

    let outputs = vec![
        UOp::bufferize_global(global_value, vec![global_range]),
        UOp::bufferize_local(local_value, vec![local_range]),
    ];
    let (split, _context) = run_kernel_split_pipeline(UOp::sink(outputs));

    let param_total = count_codegen_params(&split);
    assert!(param_total >= 1, "Should have at least 1 codegen PARAM from global BUFFERIZE");

    let remaining_bufferizes = count_bufferizes(&split);
    assert!(remaining_bufferizes >= 1, "Local BUFFERIZE should remain unconverted (handled later in codegen)");
}
/// Runs the kernel-split pipeline on a single `bufferize_global` node and inspects
/// the KERNEL it produces.
///
/// The KERNEL's `ast` must be a SINK, its source list must be non-empty, and at
/// least one source must be a BUFFER.
///
/// # Panics
///
/// Panics (fails the test) if no KERNEL can be extracted or if the extracted node
/// is not an `Op::Kernel`.
#[test]
fn test_pipeline_preserves_structure() {
    let compute = UOp::native_const(PI);
    let range = UOp::range_const(20, 0);
    // Every value below is consumed exactly once, so it is moved rather than cloned
    // (same pattern as the other single-BUFFERIZE tests in this file).
    let bufferize = UOp::bufferize_global(compute, vec![range]);
    let (result, _context) = run_kernel_split_pipeline(bufferize);

    let kernel = extract_kernel(&result).expect("Pipeline should create a KERNEL");
    match kernel.op() {
        Op::Kernel { ast, sources } => {
            assert!(matches!(ast.op(), Op::Sink { .. }));
            assert!(!sources.is_empty(), "KERNEL should have sources");
            assert!(sources.iter().any(|s| matches!(s.op(), Op::Buffer { .. })), "KERNEL sources should include BUFFER");
        }
        _ => panic!("Expected KERNEL operation, got {:?}", kernel.op()),
    }
}
/// Runs the kernel-split pipeline on a single `bufferize_global` of a boolean
/// constant and checks the sources of the resulting KERNEL.
///
/// The KERNEL is expected to carry exactly one source, and that source must be a
/// BUFFER.
#[test]
fn test_pipeline_context_threading() {
    // Arguments are evaluated left to right: constant, then range, then BUFFERIZE.
    let buffered = UOp::bufferize_global(UOp::native_const(true), vec![UOp::range_const(8, 0)]);
    let (output, _context) = run_kernel_split_pipeline(buffered);
    let kernel = extract_kernel(&output).expect("Pipeline should create a KERNEL");

    match kernel.op() {
        Op::Kernel { sources, .. } => {
            assert_eq!(sources.len(), 1, "KERNEL should have 1 source (the buffer from stage 1)");
            let only_source = &sources[0];
            assert!(
                matches!(only_source.op(), Op::Buffer { .. }),
                "Source should be BUFFER, got {:?}",
                only_source.op()
            );
        }
        _ => panic!("Expected KERNEL operation, got {:?}", kernel.op()),
    }
}
/// Builds a `bufferize_global` and a `bufferize_local` node over the *same* range,
/// sinks both, and runs the kernel-split pipeline.
///
/// Passes when the result holds at least one codegen PARAM and at least one
/// BUFFERIZE that was left in place.
#[test]
fn test_pipeline_mixed_addrspace() {
    let (global_value, local_value) = (UOp::native_const(1.0f32), UOp::native_const(2.0f32));
    let shared_range = UOp::range_const(16, 0);

    // The range feeds both BUFFERIZE nodes, so the first use needs its own handle.
    let outputs = vec![
        UOp::bufferize_global(global_value, vec![shared_range.clone()]),
        UOp::bufferize_local(local_value, vec![shared_range]),
    ];
    let (split, _context) = run_kernel_split_pipeline(UOp::sink(outputs));

    let param_total = count_codegen_params(&split);
    assert!(param_total >= 1, "Should have at least 1 codegen PARAM from global BUFFERIZE");

    let remaining_bufferizes = count_bufferizes(&split);
    assert!(remaining_bufferizes >= 1, "Local BUFFERIZE should remain for later codegen conversion");
}
/// Adds two reshaped input buffers, marks the sum contiguous, and runs `rangeify`
/// on the resulting SINK.
///
/// Passes when no RESHAPE node appears in the topological order of the rangeified
/// graph.
#[test]
fn test_pipeline_reshape_buffer_to_load() {
    use morok_dtype::DType;

    // Builds one input: a fresh 12-element Float32 CPU buffer wrapped in a RESHAPE
    // whose shape operand is the vectorized pair (3, 4).
    let reshaped_input = || {
        let buffer = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 12, DType::Float32);
        let shape = UOp::vectorize(vec![UOp::index_const(3), UOp::index_const(4)].into());
        UOp::new(Op::Reshape { src: buffer, new_shape: shape }, DType::Float32)
    };

    let lhs = reshaped_input();
    let rhs = reshaped_input();
    let summed = lhs.try_add(&rhs).expect("Add should work");
    let root = UOp::sink(vec![summed.contiguous()]);

    let (rangeified, _ctx) = crate::rangeify::rangeify(root, None).expect("Rangeify should succeed");

    let reshape_survives = rangeified.toposort().iter().any(|node| matches!(node.op(), Op::Reshape { .. }));
    assert!(!reshape_survives, "RESHAPE should be eliminated after rangeify");
}
/// End-to-end check: `rangeify` followed by the kernel-split pipeline on the sum of
/// two reshaped input buffers.
///
/// Both intermediate graphs are dumped to stdout for debugging (visible with
/// `cargo test -- --nocapture`). The test passes when the kernelized graph has at
/// least one INDEX whose buffer operand is a BUFFER or a PARAM with `device: None`.
#[test]
fn test_full_pipeline_creates_load_for_input_buffers() {
    use morok_dtype::DType;

    let input_buffer = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 12, DType::Float32);
    let reshape_shape = UOp::vectorize(vec![UOp::index_const(3), UOp::index_const(4)].into());
    let reshaped = UOp::new(Op::Reshape { src: input_buffer, new_shape: reshape_shape }, DType::Float32);
    let input_buffer2 = UOp::new_buffer(morok_device::DeviceSpec::Cpu, 12, DType::Float32);
    let reshape_shape2 = UOp::vectorize(vec![UOp::index_const(3), UOp::index_const(4)].into());
    let reshaped2 = UOp::new(Op::Reshape { src: input_buffer2, new_shape: reshape_shape2 }, DType::Float32);
    let compute = reshaped.try_add(&reshaped2).expect("Add should work");
    let contiguous = compute.contiguous();
    let sink = UOp::sink(vec![contiguous]);

    // `sink` is not used again, so it is handed over by value instead of cloned.
    let (rangeified, _ctx) = crate::rangeify::rangeify(sink, None).expect("Rangeify should succeed");

    // Debug dump of the rangeified graph; node kinds not listed here are skipped.
    println!("After rangeify:");
    for node in rangeified.toposort() {
        let op_name = match node.op() {
            Op::Bufferize { .. } => "BUFFERIZE",
            Op::Buffer { .. } => "BUFFER",
            Op::Index { .. } => "INDEX",
            Op::Binary(morok_ir::BinaryOp::Add, _, _) => "ADD",
            Op::Reshape { .. } => "RESHAPE",
            Op::Sink { .. } => "SINK",
            _ => continue,
        };
        println!(" {} [{:?}]", op_name, node.dtype());
    }
    println!("Rangeified root op: {:?}", rangeified.op());

    let (kernelized, _context) = run_kernel_split_pipeline(rangeified);

    // One topological walk serves both the debug dump and the assertion below.
    let topo = kernelized.toposort();

    println!("Kernelized graph ops:");
    for node in topo.iter() {
        let op_name = match node.op() {
            Op::Kernel { .. } => "KERNEL",
            Op::Sink { .. } => "SINK",
            Op::Store { .. } => "STORE",
            Op::Load { .. } => "LOAD",
            // INDEX nodes get their own line format naming the kind of buffer indexed.
            Op::Index { buffer, .. } => {
                let buf_name = match buffer.op() {
                    Op::Buffer { .. } => "BUFFER",
                    Op::Param { device: None, .. } => "PARAM",
                    Op::DefineLocal(_) => "DEFINE_LOCAL",
                    _ => "OTHER",
                };
                println!(" INDEX(buf={}) [{:?}]", buf_name, node.dtype());
                continue;
            }
            Op::Binary(morok_ir::BinaryOp::Add, _, _) => "ADD",
            Op::Binary(..) => "BINARY",
            Op::Param { slot, device: None, .. } => {
                println!(" PARAM({})", slot);
                continue;
            }
            Op::DefineLocal(id) => {
                println!(" DEFINE_LOCAL({})", id);
                continue;
            }
            Op::Buffer { .. } => "BUFFER",
            Op::Range { .. } => "RANGE",
            Op::Const(_) => "CONST",
            Op::End { .. } => "END",
            _ => "OTHER",
        };
        println!(" {} [{:?}]", op_name, node.dtype());
    }

    // Only existence matters, so stop at the first matching INDEX instead of
    // collecting every match into a Vec.
    let has_index_on_buffer = topo.iter().any(|node| match node.op() {
        Op::Index { buffer, .. } => matches!(buffer.op(), Op::Buffer { .. } | Op::Param { device: None, .. }),
        _ => false,
    });
    assert!(has_index_on_buffer, "Pipeline should create INDEX operations for input buffers");
}
/// Chains two `bufferize_global` nodes (the first is the value of the second) and
/// runs the kernel-split pipeline on the outer one.
///
/// Currently ignored; see the `#[ignore]` reason. When enabled, it requires at
/// least one KERNEL and at least one codegen PARAM in the result.
#[test]
#[ignore = "Pipeline doesn't handle complex chaining yet"]
fn test_pipeline_chained_operations() {
    let compute_a = UOp::native_const(1i32);
    let range_a = UOp::range_const(10, 0);
    // `compute_a` is used only here, so it is moved rather than cloned.
    let buf_a = UOp::bufferize_global(compute_a, vec![range_a]);

    // The inner BUFFERIZE is itself the value stored by the outer one.
    let range_b = UOp::range_const(5, 1);
    let buf_b = UOp::bufferize_global(buf_a, vec![range_b]);

    let (result, _context) = run_kernel_split_pipeline(buf_b);

    let kernel_count = count_kernels(&result);
    assert!(kernel_count >= 1, "Should create at least 1 kernel");

    let param_count = count_codegen_params(&result);
    assert!(param_count >= 1, "Should create at least 1 buffer");
}