use std::sync::Arc;
use morok_ir::{Op, UOp};
use crate::rangeify::kernel::split_store;
use crate::rangeify::{KernelContext, bufferize_to_store};
#[allow(unused_imports)]
use crate::test::unit::rangeify::helpers::extract_kernel;
/// Test helper: invokes `split_store` on `x` with a fresh, empty scratch
/// vector and returns whatever `split_store` returns.
fn call_split_store(x: &Arc<UOp>) -> Option<Arc<UOp>> {
    split_store(&mut Vec::new(), x)
}
/// Checks the node tree returned by `bufferize_to_store` for a global
/// BUFFERIZE over a single range: an AFTER whose passthrough is a BUFFER and
/// whose only dependency is an END (over that same range) wrapping a STORE
/// that writes the original value through an INDEX of the same buffer.
#[test]
fn test_bufferize_to_store_global() {
    let mut ctx = KernelContext::new();
    let value_src = UOp::native_const(42.0f32);
    let loop_range = UOp::range_const(10, 0);
    let bufferize = UOp::bufferize_global(value_src.clone(), vec![loop_range.clone()]);

    let lowered = bufferize_to_store(&bufferize, &mut ctx, true);
    assert!(lowered.is_some());
    let lowered = lowered.unwrap();

    // Outermost node: AFTER with a BUFFER passthrough and exactly one dep.
    let (passthrough, deps) = match lowered.op() {
        Op::After { passthrough, deps } => (passthrough, deps),
        other => panic!("Expected AFTER operation, got {:?}", other),
    };
    assert!(matches!(passthrough.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", passthrough.op());
    assert_eq!(deps.len(), 1);

    // The single dep is an END closing the very range we passed in.
    let end_node = &deps[0];
    let (computation, end_ranges) = match end_node.op() {
        Op::End { computation, ranges } => (computation, ranges),
        other => panic!("Expected END operation in deps, got {:?}", other),
    };
    assert_eq!(end_ranges.len(), 1);
    assert!(Arc::ptr_eq(&end_ranges[0], &loop_range));

    // Inside the END: a STORE that carries no ranges of its own.
    let Op::Store { index, value, ranges: store_ranges } = computation.op() else {
        panic!("Expected STORE operation inside END, got {:?}", computation.op());
    };
    assert!(store_ranges.is_empty());

    // The STORE target indexes the same buffer node the AFTER passes through.
    let Op::Index { buffer: indexed_buffer, .. } = index.op() else {
        panic!("Expected INDEX operation, got {:?}", index.op());
    };
    assert!(matches!(indexed_buffer.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", indexed_buffer.op());
    assert!(Arc::ptr_eq(indexed_buffer, passthrough));
    assert!(Arc::ptr_eq(value, &value_src));
    assert!(matches!(index.op(), Op::Index { .. }));

    // Context bookkeeping: no local buffer was counted, and the BUFFERIZE is tracked.
    assert_eq!(ctx.local_counter, 0);
    assert!(ctx.has_buffer(&bufferize));
}
/// Checks the node tree returned by `bufferize_to_store` for a local
/// BUFFERIZE over a single range: an AFTER whose passthrough is
/// `DefineLocal(0)` and whose only dependency is a BARRIER wrapping the
/// END/STORE chain, with the local counter advanced to 1.
#[test]
fn test_bufferize_to_store_local_with_barrier() {
    let mut ctx = KernelContext::new();
    let value_src = UOp::native_const(1.0f32);
    let loop_range = UOp::range_const(5, 0);
    let bufferize = UOp::bufferize_local(value_src.clone(), vec![loop_range.clone()]);

    let lowered = bufferize_to_store(&bufferize, &mut ctx, true);
    assert!(lowered.is_some());
    let lowered = lowered.unwrap();

    // Outermost node: AFTER with the first local definition as passthrough.
    let (passthrough, deps) = match lowered.op() {
        Op::After { passthrough, deps } => (passthrough, deps),
        other => panic!("Expected AFTER operation, got {:?}", other),
    };
    assert!(matches!(passthrough.op(), Op::DefineLocal(0)));
    assert_eq!(deps.len(), 1);

    // The single dep is a BARRIER whose source is the END node.
    let barrier_node = &deps[0];
    let Op::Barrier { src: barrier_src, .. } = barrier_node.op() else {
        panic!("Expected BARRIER operation in deps");
    };
    let (computation, end_ranges) = match barrier_src.op() {
        Op::End { computation, ranges } => (computation, ranges),
        other => panic!("Expected END operation inside BARRIER, got {:?}", other),
    };
    assert_eq!(end_ranges.len(), 1);
    assert!(Arc::ptr_eq(&end_ranges[0], &loop_range));

    // Inside the END: a STORE with no ranges that writes the original value.
    let Op::Store { index, value, ranges: store_ranges } = computation.op() else {
        panic!("Expected STORE operation inside END, got {:?}", computation.op());
    };
    assert!(store_ranges.is_empty());
    let Op::Index { buffer: indexed_buffer, .. } = index.op() else {
        panic!("Expected INDEX operation, got {:?}", index.op());
    };
    assert!(matches!(indexed_buffer.op(), Op::DefineLocal(0)));
    assert!(Arc::ptr_eq(value, &value_src));

    // Context bookkeeping: only the local counter moved.
    assert_eq!(ctx.global_counter, 0);
    assert_eq!(ctx.local_counter, 1);
    assert!(ctx.has_buffer(&bufferize));
}
// Global BUFFERIZE over two ranges.
//
// NOTE(review): because of `should_panic`, this test passes only when something
// in the body panics with a message containing "unexpected multi-range". None of
// the `panic!`/`assert!` messages written in this body contain that text, so the
// panic presumably comes from one of the project functions called below (most
// likely `bufferize_to_store`) — confirm against its implementation. If that is
// the case, every assertion after the call is unreachable and only records the
// shape that would be expected from a multi-range lowering.
#[test]
#[should_panic(expected = "unexpected multi-range")]
fn test_bufferize_to_store_multiple_ranges() {
let mut ctx = KernelContext::new();
let compute = UOp::native_const(100i32);
// Two distinct ranges: sizes 4 and 8, with second arguments 0 and 1.
let range1 = UOp::range_const(4, 0);
let range2 = UOp::range_const(8, 1);
let bufferize = UOp::bufferize_global(compute.clone(), vec![range1.clone(), range2.clone()]);
let result = bufferize_to_store(&bufferize, &mut ctx, true);
assert!(result.is_some());
let result = result.unwrap();
// Expected outer node: AFTER with a BUFFER passthrough and a single dep.
let Op::After { passthrough, deps } = result.op() else {
panic!("Expected AFTER operation, got {:?}", result.op());
};
assert!(matches!(passthrough.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", passthrough.op());
assert_eq!(deps.len(), 1);
// Expected dep: an END that closes both ranges, in the order they were supplied.
let Op::End { computation, ranges: end_ranges } = deps[0].op() else {
panic!("Expected END operation in deps, got {:?}", deps[0].op());
};
assert_eq!(end_ranges.len(), 2);
assert!(std::sync::Arc::ptr_eq(&end_ranges[0], &range1));
assert!(std::sync::Arc::ptr_eq(&end_ranges[1], &range2));
// Expected END payload: a STORE without ranges that writes the original value.
let Op::Store { index, value, ranges } = computation.op() else {
panic!("Expected STORE operation inside END, got {:?}", computation.op());
};
assert!(ranges.is_empty());
assert!(std::sync::Arc::ptr_eq(value, &compute));
// Expected STORE target: an INDEX into the passthrough buffer with exactly one index.
let Op::Index { buffer: idx_buffer, indices, .. } = index.op() else {
panic!("Expected INDEX operation");
};
assert!(matches!(idx_buffer.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", idx_buffer.op());
assert!(std::sync::Arc::ptr_eq(idx_buffer, passthrough));
assert_eq!(indices.len(), 1, "Multi-index should be linearized to single index");
assert!(ctx.has_buffer(&bufferize));
}
/// `bufferize_to_store` must return `None` when handed a plain constant
/// rather than a BUFFERIZE node.
#[test]
fn test_non_bufferize_returns_none() {
    let mut ctx = KernelContext::new();
    let plain_const = UOp::native_const(1.0f32);

    let outcome = bufferize_to_store(&plain_const, &mut ctx, true);

    assert!(outcome.is_none());
}
/// The context must not know a BUFFERIZE before `bufferize_to_store` runs,
/// and afterwards must map it to an AFTER node whose passthrough is a BUFFER.
#[test]
fn test_buffer_tracked_in_context() {
    let mut ctx = KernelContext::new();
    let bufferize = UOp::bufferize_global(UOp::native_const(1.0f32), vec![]);

    // Before the call: untracked.
    assert!(!ctx.has_buffer(&bufferize));

    bufferize_to_store(&bufferize, &mut ctx, true);

    // After the call: tracked, and the stored replacement has the expected shape.
    assert!(ctx.has_buffer(&bufferize));
    let replacement = ctx.get_buffer(&bufferize).unwrap();
    let passthrough = match replacement.op() {
        Op::After { passthrough, .. } => passthrough,
        other => panic!("Expected AFTER operation, got {:?}", other),
    };
    assert!(matches!(passthrough.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", passthrough.op());
}
/// Lowers three global BUFFERIZE nodes through one shared context and checks
/// that each call succeeds, is tracked, and leaves the local counter at 0.
#[test]
fn test_bufferize_to_store_sequential_global_ids() {
    let mut ctx = KernelContext::new();

    for i in 0..3 {
        // A distinct constant per iteration.
        let payload = UOp::native_const((i as f64) as f32);
        let bufferize = UOp::bufferize_global(payload, vec![]);

        let lowered = bufferize_to_store(&bufferize, &mut ctx, true);

        assert!(lowered.is_some());
        assert!(ctx.has_buffer(&bufferize));
        assert_eq!(ctx.local_counter, 0);
    }
}
/// Lowers three local BUFFERIZE nodes through one shared context and checks
/// that the local counter grows by one per call while the global counter
/// stays at 0.
#[test]
fn test_bufferize_to_store_sequential_local_ids() {
    let mut ctx = KernelContext::new();

    for i in 0..3 {
        // A distinct constant per iteration.
        let payload = UOp::native_const((i as f64) as f32);
        let bufferize = UOp::bufferize_local(payload, vec![]);

        bufferize_to_store(&bufferize, &mut ctx, true);

        assert_eq!(ctx.global_counter, 0);
        assert_eq!(ctx.local_counter, (i + 1) as usize);
    }
}
/// Lowers one global and one local range-less BUFFERIZE through the same
/// context. The global result must be `AFTER(BUFFER, [STORE])`, the local
/// result `AFTER(DefineLocal(0), [BARRIER(STORE)])`, and both must be tracked.
#[test]
fn test_bufferize_to_store_mixed_global_local() {
    let mut ctx = KernelContext::new();
    let global_compute = UOp::native_const(1.0f32);
    let local_compute = UOp::native_const(2.0f32);
    let global_bufferize = UOp::bufferize_global(global_compute.clone(), vec![]);
    let local_bufferize = UOp::bufferize_local(local_compute.clone(), vec![]);

    // Global first, then local, against the same context.
    let global_lowered = bufferize_to_store(&global_bufferize, &mut ctx, true);
    let local_lowered = bufferize_to_store(&local_bufferize, &mut ctx, true);
    assert_eq!(ctx.local_counter, 1);

    // Global side: the STORE sits directly in the AFTER deps.
    let global_lowered = global_lowered.unwrap();
    let (global_buf, global_deps) = match global_lowered.op() {
        Op::After { passthrough, deps } => (passthrough, deps),
        other => panic!("Expected AFTER operation for global, got {:?}", other),
    };
    assert!(matches!(global_buf.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", global_buf.op());
    assert_eq!(global_deps.len(), 1);
    let global_store = &global_deps[0];
    let Op::Store { index: global_index, value: global_value, .. } = global_store.op() else {
        panic!("Expected STORE in global AFTER deps, got {:?}", global_store.op());
    };
    let Op::Index { buffer: global_target, .. } = global_index.op() else {
        panic!("Expected INDEX operation, got {:?}", global_index.op());
    };
    assert!(Arc::ptr_eq(global_target, global_buf));
    assert!(Arc::ptr_eq(global_value, &global_compute));

    // Local side: the STORE is wrapped in a BARRIER.
    let local_lowered = local_lowered.unwrap();
    let (local_buf, local_deps) = match local_lowered.op() {
        Op::After { passthrough, deps } => (passthrough, deps),
        other => panic!("Expected AFTER operation for local, got {:?}", other),
    };
    assert!(matches!(local_buf.op(), Op::DefineLocal(0)));
    assert_eq!(local_deps.len(), 1);
    let local_barrier = &local_deps[0];
    let Op::Barrier { src: barrier_src, .. } = local_barrier.op() else {
        panic!("Expected BARRIER in local AFTER deps, got {:?}", local_barrier.op());
    };
    let Op::Store { index: local_index, value: local_value, .. } = barrier_src.op() else {
        panic!("Expected STORE inside BARRIER, got {:?}", barrier_src.op());
    };
    let Op::Index { buffer: local_target, .. } = local_index.op() else {
        panic!("Expected INDEX operation, got {:?}", local_index.op());
    };
    assert!(Arc::ptr_eq(local_target, local_buf));
    assert!(Arc::ptr_eq(local_value, &local_compute));

    // Both BUFFERIZE nodes are registered in the context.
    assert!(ctx.has_buffer(&global_bufferize));
    assert!(ctx.has_buffer(&local_bufferize));
}
/// End-to-end check: lowers a global BUFFERIZE with `bufferize_to_store`,
/// feeds the resulting END node to `split_store`, and verifies that a KERNEL
/// with a non-empty source list and a single-source SINK AST comes back.
#[test]
fn test_bufferize_to_store_integration_with_split_kernel() {
    let mut ctx = KernelContext::new();
    let value_src = UOp::native_const(42.0f32);
    let loop_range = UOp::range_const(10, 0);
    let bufferize = UOp::bufferize_global(value_src.clone(), vec![loop_range]);

    // Stage 1: BUFFERIZE -> AFTER(BUFFER, [END(STORE)]).
    let store_result = bufferize_to_store(&bufferize, &mut ctx, true).unwrap();
    assert!(ctx.has_buffer(&bufferize));
    let (buffer_node, deps) = match store_result.op() {
        Op::After { passthrough, deps } => (passthrough, deps),
        other => panic!("Expected AFTER operation, got {:?}", other),
    };
    assert!(matches!(buffer_node.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", buffer_node.op());
    assert_eq!(deps.len(), 1);

    let end_op = &deps[0];
    let (computation, end_ranges) = match end_op.op() {
        Op::End { computation, ranges } => (computation, ranges),
        other => panic!("Expected END in AFTER deps, got {:?}", other),
    };
    assert_eq!(end_ranges.len(), 1);
    assert!(matches!(computation.op(), Op::Store { .. }), "Expected STORE inside END");

    // Stage 2: the END node is split into a KERNEL whose AST is a SINK.
    let kernel = call_split_store(end_op).expect("split_store should create a KERNEL");
    let Op::Kernel { sources: kernel_sources, ast } = kernel.op() else {
        panic!("Expected KERNEL operation, got {:?}", kernel.op());
    };
    assert!(!kernel_sources.is_empty(), "KERNEL should have at least one source");
    let Op::Sink { sources: sink_sources } = ast.op() else {
        panic!("Expected SINK operation in kernel AST, got {:?}", ast.op());
    };
    assert_eq!(sink_sources.len(), 1, "SINK should have 1 source");

    // The context still maps the BUFFERIZE to an AFTER over a BUFFER.
    let ctx_buffer = ctx.get_buffer(&bufferize).unwrap();
    let Op::After { passthrough: tracked_buffer, .. } = ctx_buffer.op() else {
        panic!("Expected AFTER in context buffer mapping");
    };
    assert!(matches!(tracked_buffer.op(), Op::Buffer { .. }), "Expected BUFFER, got {:?}", tracked_buffer.op());
}