use std::sync::Arc;
use morok_device::DeviceSpec;
use morok_dtype::DType;
use morok_ir::{AxisId, AxisType, Op, UOp};
use crate::rangeify::kernel::{LocalAddBufferContext, split_store};
use crate::rangeify::patterns::rangeify_codegen_patterns;
use crate::rangeify::transforms::find_bufs;
/// Test helper: invokes `split_store` on `x` with a fresh, empty scratch list.
///
/// The list handed in as the first argument is thrown away when this helper
/// returns; callers only see the optional node that `split_store` produced.
fn call_split_store(x: &Arc<UOp>) -> Option<Arc<UOp>> {
    let mut scratch = Vec::new();
    split_store(&mut scratch, x)
}
/// Test helper: runs the rangeify codegen pattern matcher once over `uop`.
///
/// A new `LocalAddBufferContext` is created for every call. Returns
/// `Some(node)` only when the matcher reports `RewriteResult::Rewritten`;
/// every other outcome is collapsed to `None`.
fn apply_codegen_patterns(uop: &Arc<UOp>) -> Option<Arc<UOp>> {
    let patterns = rangeify_codegen_patterns();
    let mut buffer_ctx = LocalAddBufferContext::new();
    if let morok_ir::pattern::RewriteResult::Rewritten(rewritten) = patterns.rewrite(uop, &mut buffer_ctx) {
        Some(rewritten)
    } else {
        None
    }
}
/// Feeding a bare NOOP node to the codegen patterns must not report a rewrite.
#[test]
fn test_remove_noop_in_pipeline() {
    let rewritten = apply_codegen_patterns(&UOp::noop());
    assert!(rewritten.is_none());
}
/// Wraps a constant in CONTIGUOUS and checks that the codegen patterns rewrite
/// it to a node that is pointer-identical to the wrapped constant.
#[test]
fn test_get_contiguous_in_pipeline() {
    let scalar = UOp::native_const(42.0f32);
    let wrapped = scalar.contiguous();

    let rewritten = apply_codegen_patterns(&wrapped);
    assert!(rewritten.is_some());

    // The assertion above guarantees this branch is taken.
    if let Some(inner) = &rewritten {
        assert!(Arc::ptr_eq(inner, &scalar));
    }
}
/// Builds `dst[0] = src[0] * 2.0` over two distinct buffers and checks that
/// `find_bufs` reports exactly two entries for the resulting STORE.
#[test]
fn test_no_cycle_valid_access_pattern() {
    let src_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let dst_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);

    // Read side: load element 0 of the source buffer.
    let read_at = UOp::index()
        .buffer(src_buf.clone())
        .indices(vec![zero.clone()])
        .call()
        .unwrap();
    let read = UOp::load().buffer(src_buf.clone()).index(read_at).call();

    // Compute: scale the loaded value by a constant.
    let factor = UOp::native_const(2.0f32);
    let scaled = read.try_mul(&factor).unwrap();

    // Write side: store the product into element 0 of the destination buffer.
    let write_at = UOp::index()
        .buffer(dst_buf.clone())
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = write_at.store(scaled);

    // NOTE(review): the lint is silenced for this binding only; presumably
    // `find_bufs` returns a map keyed by UOp nodes - confirm.
    #[allow(clippy::mutable_key_type)]
    let accesses = find_bufs(&store);
    assert_eq!(accesses.len(), 2);
}
/// Stores a constant through a constant index and hands the bare STORE to
/// `split_store`.
///
/// The check is conditional: the `Op::Kernel` assertion only runs when
/// `split_store` returns `Some`; a `None` result passes without any assertion.
#[test]
fn test_split_store_simple_kernel() {
    let target = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);

    let write_at = UOp::index()
        .buffer(target)
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = write_at.store(one);

    if let Some(kernel) = call_split_store(&store) {
        assert!(matches!(kernel.op(), Op::Kernel { .. }));
    }
}
/// Wraps a constant STORE in an END over a single loop range of extent 10 and
/// hands the END node to `split_store`.
///
/// As in the simple-kernel test, the `Op::Kernel` assertion only runs when
/// `split_store` returns `Some`; `None` is accepted silently.
#[test]
fn test_split_store_with_loop_ranges() {
    let target = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);

    let write_at = UOp::index()
        .buffer(target)
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = write_at.store(one);

    // Close the store over one loop axis.
    let extent = UOp::index_const(10);
    let axis = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Loop);
    let closed = store.end(vec![axis].into());

    if let Some(kernel) = call_split_store(&closed) {
        assert!(matches!(kernel.op(), Op::Kernel { .. }));
    }
}
/// Checks that the codegen patterns report a rewrite for a constant wrapped in
/// CONTIGUOUS. Only the presence of a rewrite is asserted, not its contents.
#[test]
fn test_pattern_application_order() {
    let scalar = UOp::native_const(1.0f32);
    let rewritten = apply_codegen_patterns(&scalar.contiguous());
    assert!(rewritten.is_some());
}
/// Builds `out[0] = a[0] + b[0]` over three distinct buffers and checks that
/// `find_bufs` reports exactly three entries for the resulting STORE.
#[test]
fn test_multiple_buffer_integration() {
    let lhs_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let rhs_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let dst_buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);

    // First operand: element 0 of the left-hand buffer.
    let lhs_at = UOp::index()
        .buffer(lhs_buf.clone())
        .indices(vec![zero.clone()])
        .call()
        .unwrap();
    let lhs = UOp::load().buffer(lhs_buf.clone()).index(lhs_at).call();

    // Second operand: element 0 of the right-hand buffer.
    let rhs_at = UOp::index()
        .buffer(rhs_buf.clone())
        .indices(vec![zero.clone()])
        .call()
        .unwrap();
    let rhs = UOp::load().buffer(rhs_buf.clone()).index(rhs_at).call();

    let total = lhs.try_add(&rhs).unwrap();

    // Write the sum into element 0 of the destination buffer.
    let write_at = UOp::index()
        .buffer(dst_buf.clone())
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = write_at.store(total);

    // NOTE(review): the lint is silenced for this binding only; presumably
    // `find_bufs` returns a map keyed by UOp nodes - confirm.
    #[allow(clippy::mutable_key_type)]
    let accesses = find_bufs(&store);
    assert_eq!(accesses.len(), 3);
}
/// Verifies the shape of an END node built around a STORE, then passes it to
/// `split_store`.
///
/// The structural check is strict: the node must be `Op::End` and its
/// `computation` must be `Op::Store`, otherwise the test fails. The follow-up
/// `Op::Kernel` check is conditional and only runs when `split_store` returns
/// `Some`.
#[test]
fn test_end_store_structure() {
    let target = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);

    let write_at = UOp::index()
        .buffer(target)
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = write_at.store(one);

    // Close the store over one loop axis of extent 10.
    let extent = UOp::index_const(10);
    let axis = UOp::range_axis(extent, AxisId::Renumbered(0), AxisType::Loop);
    let closed = store.end(vec![axis].into());

    match closed.op() {
        Op::End { computation, .. } => {
            assert!(matches!(computation.op(), Op::Store { .. }));
        }
        _ => panic!("Expected END operation"),
    }

    if let Some(kernel) = call_split_store(&closed) {
        assert!(matches!(kernel.op(), Op::Kernel { .. }));
    }
}