use std::sync::Arc;
use morok_dtype::{DType, DeviceSpec};
use morok_ir::{AxisId, AxisType, ConstValue, Op, UOp};
use smallvec::smallvec;
use crate::rangeify::kernel::split_store;
/// Runs [`split_store`] on `x` with a fresh, throwaway context vector and returns its result.
///
/// Whatever `split_store` writes into the vector is discarded; the tests in this file only
/// inspect the returned node.
fn call_split_store(x: &Arc<UOp>) -> Option<Arc<UOp>> {
    split_store(&mut Vec::new(), x)
}
/// A plain `STORE(INDEX(buffer, 0), 1.0)` is split into a KERNEL with a non-empty source list.
#[test]
fn test_split_store_basic() {
    let buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);
    let idx = UOp::index().buffer(buf).indices(vec![zero]).call().unwrap();
    let stored = idx.store(one);

    let Some(kernel) = call_split_store(&stored) else {
        panic!("split_store should produce a kernel for a STORE");
    };
    let Op::Kernel { sources, .. } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    assert!(!sources.is_empty(), "Kernel sources should contain the buffer");
}
/// A bare constant is neither a STORE nor an END, so `split_store` declines it.
#[test]
fn test_split_store_non_store_returns_none() {
    let not_a_store = UOp::native_const(1.0f32);
    assert!(call_split_store(&not_a_store).is_none());
}
/// `END(STORE, [range])` becomes a KERNEL whose AST is a SINK holding exactly one END,
/// and that END still carries its single range.
#[test]
fn test_split_store_end_operation() {
    let buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);
    let idx = UOp::index().buffer(buf).indices(vec![zero]).call().unwrap();
    let stored = idx.store(one);
    let end = stored.end(smallvec![UOp::range_const(10, 0)]);

    let Some(kernel) = call_split_store(&end) else {
        panic!("split_store should produce a kernel for END(STORE)");
    };
    let Op::Kernel { sources, ast } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    assert!(!sources.is_empty(), "Kernel sources should contain buffer mappings");

    let Op::Sink { sources: sink_sources } = ast.op() else {
        panic!("Expected SINK operation");
    };
    assert_eq!(sink_sources.len(), 1);
    match sink_sources[0].op() {
        Op::End { ranges, .. } => assert_eq!(ranges.len(), 1),
        _ => panic!("Expected END operation in SINK"),
    }
}
/// An END that wraps something other than a STORE (here a NOOP) is not split into a kernel.
#[test]
fn test_split_store_end_non_store_returns_none() {
    let end = UOp::noop().end(smallvec![UOp::range_const(10, 0)]);
    assert!(call_split_store(&end).is_none());
}
/// The kernel AST is `SINK(STORE(INDEX(PARAM, ..), value))`.
///
/// The buffer is replaced by a device-less codegen PARAM, while the stored value is the
/// very same node that was passed in (pointer-equal).
#[test]
fn test_split_store_creates_sink() {
    let buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let value = UOp::native_const(1.0f32);
    let idx = UOp::index().buffer(buf).indices(vec![zero]).call().unwrap();
    let stored = idx.store(Arc::clone(&value));
    let kernel = call_split_store(&stored).unwrap();

    let Op::Kernel { sources, ast } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    let Op::Sink { sources: sink_sources } = ast.op() else {
        panic!("Expected SINK operation");
    };
    assert_eq!(sink_sources.len(), 1);

    let Op::Store { index: store_index, value: store_val, .. } = sink_sources[0].op() else {
        panic!("Expected STORE in SINK sources");
    };
    let Op::Index { buffer: store_buf, .. } = store_index.op() else {
        panic!("Expected INDEX operation in STORE, got {:?}", store_index.op());
    };
    assert!(
        matches!(store_buf.op(), Op::Param { device: None, .. }),
        "Expected codegen PARAM, got {:?}",
        store_buf.op()
    );
    assert!(Arc::ptr_eq(store_val, &value));

    assert!(!sources.is_empty(), "Kernel sources should contain buffer mappings");
}
/// For each element type in the table, the value fed to STORE reaches the kernel AST
/// unchanged: same dtype as the buffer, and the same node (pointer-equal).
///
/// Each stage of the expected `KERNEL -> SINK -> STORE` structure is checked explicitly, so a
/// structural mismatch fails the test instead of silently skipping the assertions.
#[test]
fn test_split_store_preserves_computation() {
    let test_cases = [
        (DType::Float32, ConstValue::Float(1.0)),
        (DType::Int32, ConstValue::Int(1)),
        (DType::Bool, ConstValue::Bool(true)),
    ];
    for (dtype, const_val) in &test_cases {
        let buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, dtype.clone());
        let const_idx = UOp::index_const(0);
        // Build the stored constant from the table entry itself, so the value and the
        // buffer dtype cannot drift apart if rows are added or reordered.
        let value = match const_val {
            ConstValue::Float(v) => UOp::native_const(*v as f32),
            ConstValue::Int(v) => UOp::native_const(*v as i32),
            ConstValue::Bool(v) => UOp::native_const(*v),
            _ => unreachable!("test table only uses Float/Int/Bool constants"),
        };
        let store_idx = UOp::index().buffer(buffer).indices(vec![const_idx]).call().unwrap();
        let store = store_idx.store(Arc::clone(&value));

        let kernel = call_split_store(&store)
            .unwrap_or_else(|| panic!("split_store returned None for dtype {dtype:?}"));
        let Op::Kernel { ast, .. } = kernel.op() else {
            panic!("Expected KERNEL operation for dtype {dtype:?}, got {:?}", kernel.op());
        };
        let Op::Sink { sources } = ast.op() else {
            panic!("Expected SINK operation for dtype {dtype:?}, got {:?}", ast.op());
        };
        assert_eq!(sources.len(), 1, "SINK should hold exactly one STORE for dtype {dtype:?}");
        let Op::Store { value: stored_val, .. } = sources[0].op() else {
            panic!("Expected STORE in SINK for dtype {dtype:?}, got {:?}", sources[0].op());
        };
        assert_eq!(stored_val.dtype(), *dtype);
        assert!(Arc::ptr_eq(stored_val, &value));
    }
}
/// Two unrelated STOREs (different buffers and values) yield two distinct KERNEL nodes.
#[test]
fn test_split_store_multiple_calls_independent() {
    // Builds `STORE(INDEX(cpu_f32_buffer[len], 0), fill)`.
    let make_store = |len, fill: f32| {
        let buffer = UOp::new_buffer(DeviceSpec::Cpu, len, DType::Float32);
        let offset = UOp::index_const(0);
        let value = UOp::native_const(fill);
        let index = UOp::index().buffer(buffer).indices(vec![offset]).call().unwrap();
        index.store(value)
    };
    let first = make_store(100, 1.0);
    let second = make_store(200, 2.0);

    let kernel1 = call_split_store(&first).unwrap();
    let kernel2 = call_split_store(&second).unwrap();
    for kernel in [&kernel1, &kernel2] {
        assert!(matches!(kernel.op(), Op::Kernel { .. }));
    }
    assert!(!Arc::ptr_eq(&kernel1, &kernel2));
}
/// An END over two LOOP ranges keeps both ranges when it is placed in the kernel's SINK.
#[test]
fn test_split_store_end_with_multiple_ranges() {
    let buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);
    let idx = UOp::index().buffer(buf).indices(vec![zero]).call().unwrap();
    let stored = idx.store(one);
    let end = stored.end(smallvec![UOp::range_const(4, 0), UOp::range_const(8, 1)]);

    let Some(kernel) = call_split_store(&end) else {
        panic!("split_store should produce a kernel for END(STORE)");
    };
    let Op::Kernel { sources, ast } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    assert!(!sources.is_empty(), "Kernel sources should contain buffer mappings");

    let Op::Sink { sources: sink_sources } = ast.op() else {
        panic!("Expected SINK operation");
    };
    assert_eq!(sink_sources.len(), 1);
    match sink_sources[0].op() {
        Op::End { ranges, .. } => assert_eq!(ranges.len(), 2),
        _ => panic!("Expected END operation in SINK"),
    }
}
/// An END whose only range has `AxisType::Outer` is left alone: `split_store` returns `None`.
#[test]
fn test_split_store_end_with_outer_range() {
    let buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);
    let idx = UOp::index().buffer(buf).indices(vec![zero]).call().unwrap();
    let stored = idx.store(one);

    let outer = UOp::range_axis(UOp::index_const(10), AxisId::Renumbered(0), AxisType::Outer);
    let end = stored.end(smallvec![outer]);

    assert!(call_split_store(&end).is_none());
}
/// With one LOOP and one OUTER range, the outcome follows the FIRST range in the END.
///
/// LOOP first yields a kernel; OUTER first yields `None`.
#[test]
fn test_split_store_end_with_mixed_ranges() {
    let buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);
    let idx = UOp::index().buffer(buf).indices(vec![zero]).call().unwrap();
    let stored = idx.store(one);

    let loop_r = UOp::range_const(4, 0);
    let outer_r = UOp::range_axis(UOp::index_const(8), AxisId::Renumbered(1), AxisType::Outer);

    let loop_first = stored.end(smallvec![loop_r.clone(), outer_r.clone()]);
    assert!(call_split_store(&loop_first).is_some(), "first range is LOOP, should create kernel");

    let outer_first = stored.end(smallvec![outer_r, loop_r]);
    assert!(call_split_store(&outer_first).is_none(), "first range is OUTER, should skip");
}
/// Storing a cross-device COPY produces a kernel whose AST is the COPY itself, not a SINK.
#[test]
fn test_split_store_with_copy() {
    let source = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let copied = source.copy_to_device(DeviceSpec::Cuda { device_id: 0 });
    let dest = UOp::new_buffer(DeviceSpec::Cuda { device_id: 0 }, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let dest_idx = UOp::index().buffer(dest).indices(vec![zero]).call().unwrap();
    let stored = dest_idx.store(copied);

    let Some(kernel) = call_split_store(&stored) else {
        panic!("split_store should produce a kernel for STORE(COPY)");
    };
    match kernel.op() {
        Op::Kernel { ast, .. } => assert!(
            matches!(ast.op(), Op::Copy { .. }),
            "Expected COPY operation as kernel AST, got: {:?}",
            ast.op()
        ),
        _ => panic!("Expected KERNEL operation"),
    }
}
/// Storing a BUFFER_VIEW produces a kernel whose AST is that view.
///
/// The view's size (256) and offset (128) must be carried through unchanged.
#[test]
fn test_split_store_with_buffer_view() {
    let base = UOp::new_buffer(DeviceSpec::Cpu, 512, DType::Float32);
    let viewed = base.view(256, 128);
    let dest = UOp::new_buffer(DeviceSpec::Cpu, 256, DType::Float32);
    let zero = UOp::index_const(0);
    let dest_idx = UOp::index().buffer(dest).indices(vec![zero]).call().unwrap();
    let stored = dest_idx.store(viewed);

    let Some(kernel) = call_split_store(&stored) else {
        panic!("split_store should produce a kernel for STORE(BUFFER_VIEW)");
    };
    let Op::Kernel { ast, .. } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    let Op::BufferView { size, offset, .. } = ast.op() else {
        panic!("Expected BUFFER_VIEW operation as kernel AST, got: {:?}", ast.op());
    };
    assert_eq!(*size, 256);
    assert_eq!(*offset, 128);
}
/// An ordinary computed value (`1.0 + 2.0`) is wrapped in a single-entry SINK.
///
/// This contrasts with the COPY / BUFFER_VIEW special cases.
#[test]
fn test_split_store_normal_computation_uses_sink() {
    let sum = UOp::native_const(1.0f32).try_add(&UOp::native_const(2.0f32)).unwrap();
    let buf = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let idx = UOp::index().buffer(buf).indices(vec![zero]).call().unwrap();
    let stored = idx.store(sum);

    let Some(kernel) = call_split_store(&stored) else {
        panic!("split_store should produce a kernel for STORE(ADD)");
    };
    let Op::Kernel { sources, ast } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    assert!(!sources.is_empty(), "Kernel sources should contain buffer mappings");
    match ast.op() {
        Op::Sink { sources: sink_sources } => assert_eq!(sink_sources.len(), 1),
        other => panic!("Expected SINK operation for normal computation, got: {:?}", other),
    }
}
/// A COPY store wrapped in an END over a LOOP range still yields a kernel whose AST is the COPY.
#[test]
fn test_split_store_nested_copy_in_store() {
    let source = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let copied = source.copy_to_device(DeviceSpec::Cuda { device_id: 0 });
    let dest = UOp::new_buffer(DeviceSpec::Cuda { device_id: 0 }, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let dest_idx = UOp::index().buffer(dest).indices(vec![zero]).call().unwrap();
    let stored = dest_idx.store(copied);
    let end = stored.end(smallvec![UOp::range_const(10, 0)]);

    let Some(kernel) = call_split_store(&end) else {
        panic!("split_store should produce a kernel for END(STORE(COPY))");
    };
    let Op::Kernel { ast, .. } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    assert!(
        matches!(ast.op(), Op::Copy { .. }),
        "Expected COPY operation as kernel AST even when nested, got: {:?}",
        ast.op()
    );
}
/// Storing a chained copy (CPU -> CUDA -> CPU) produces a kernel whose AST is a COPY.
///
/// The test only asserts that the AST op is a COPY. NOTE(review): presumably this is the
/// outermost copy, but that is not checked here.
#[test]
fn test_split_store_copy_precedence_documented() {
    let round_trip = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32)
        .copy_to_device(DeviceSpec::Cuda { device_id: 0 })
        .copy_to_device(DeviceSpec::Cpu);
    let dest = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let dest_idx = UOp::index().buffer(dest).indices(vec![zero]).call().unwrap();
    let stored = dest_idx.store(round_trip);

    let Some(kernel) = call_split_store(&stored) else {
        panic!("split_store should produce a kernel for STORE(COPY(COPY))");
    };
    let Op::Kernel { ast, .. } = kernel.op() else {
        panic!("Expected KERNEL operation");
    };
    assert!(matches!(ast.op(), Op::Copy { .. }), "Expected COPY operation as kernel AST, got: {:?}", ast.op());
}