use std::sync::Arc;
use morok_dtype::{AddrSpace, DType};
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use super::helpers::*;
/// A LOAD whose index is a PTRCAT of two scalar indices must devectorize into
/// either a CAT of two LOADs or a two-element VECTORIZE.
#[test]
fn test_distribute_ptrcat_load_dual() {
    let buf = create_buffer(64);
    let pointers = vec![create_index(buf.clone(), 0), create_index(buf.clone(), 1)];
    let joined = UOp::ptrcat().sources(pointers).call();
    let vector_load = UOp::load().buffer(buf.clone()).index(joined).call();

    let rewritten = apply_devectorize(&vector_load);

    match rewritten.op() {
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 2, "Should have 2 LOAD elements");
        }
        Op::Cat { sources } => {
            assert_eq!(sources.len(), 2, "Should have 2 LOAD sources");
            // Every piece of the CAT has to be a LOAD in its own right.
            for part in sources.iter() {
                assert_is_load(part);
            }
        }
        other => panic!("Expected CAT or VECTORIZE, got {:?}", other),
    }
}
/// Four-pointer variant of the PTRCAT load test: the rewritten node must be a
/// CAT with four sources or a VECTORIZE with four elements.
#[test]
fn test_distribute_ptrcat_load_quad() {
    let buf = create_buffer(64);

    // Build the four index nodes (offsets 0 through 3) that feed the PTRCAT.
    let mut pointers: Vec<Arc<UOp>> = Vec::with_capacity(4);
    for offset in 0..4 {
        pointers.push(create_index(buf.clone(), offset));
    }

    let joined = UOp::ptrcat().sources(pointers).call();
    let vector_load = UOp::load().buffer(buf.clone()).index(joined).call();
    let rewritten = apply_devectorize(&vector_load);

    match rewritten.op() {
        Op::Vectorize { elements } => assert_eq!(elements.len(), 4, "Should have 4 LOAD elements"),
        Op::Cat { sources } => assert_eq!(sources.len(), 4, "Should have 4 LOAD sources"),
        other => panic!("Expected CAT or VECTORIZE, got {:?}", other),
    }
}
/// After devectorizing a two-pointer PTRCAT load, at least two LOAD nodes in
/// the result must still point at the very same buffer `Arc` (pointer
/// identity, not structural equality).
#[test]
fn test_distribute_ptrcat_preserves_buffer() {
    let buf = create_buffer(64);
    let pointers = vec![create_index(buf.clone(), 0), create_index(buf.clone(), 1)];
    let joined = UOp::ptrcat().sources(pointers).call();
    let vector_load = UOp::load().buffer(buf.clone()).index(joined).call();
    let rewritten = apply_devectorize(&vector_load);

    // Count only LOADs whose buffer operand is the original allocation.
    let loads_on_buf = count_ops(&rewritten, |node| {
        matches!(node.op(), Op::Load { buffer: loaded_from, .. } if Arc::ptr_eq(loaded_from, &buf))
    });

    assert!(loads_on_buf >= 2, "Should have at least 2 buffer references");
}
/// A STORE through a two-pointer PTRCAT must devectorize into a GROUP of two
/// STOREs; a single STORE is also accepted.
#[test]
fn test_distribute_ptrcat_store() {
    let buf = create_buffer(64);
    let payload = create_vector_float_iota(2);
    let pointers = vec![create_index(buf.clone(), 0), create_index(buf.clone(), 1)];
    let joined = UOp::ptrcat().sources(pointers).call();
    let vector_store = joined.store(payload);

    let rewritten = apply_devectorize(&vector_store);

    match rewritten.op() {
        // A lone STORE is tolerated without further checks.
        Op::Store { .. } => {}
        Op::Group { sources } => {
            assert_eq!(sources.len(), 2, "Should have 2 STORE sources");
            for part in sources.iter() {
                assert_is_store(part);
            }
        }
        other => panic!("Expected GROUP or STORE, got {:?}", other),
    }
}
/// Four-pointer variant of the PTRCAT store test: the result must be a GROUP
/// with four sources, or a single STORE.
#[test]
fn test_distribute_ptrcat_store_quad() {
    let buf = create_buffer(64);
    let payload = create_vector_float_iota(4);

    // Build the four index nodes (offsets 0 through 3) that feed the PTRCAT.
    let mut pointers: Vec<Arc<UOp>> = Vec::with_capacity(4);
    for offset in 0..4 {
        pointers.push(create_index(buf.clone(), offset));
    }

    let joined = UOp::ptrcat().sources(pointers).call();
    let vector_store = joined.store(payload);
    let rewritten = apply_devectorize(&vector_store);

    match rewritten.op() {
        Op::Store { .. } => {}
        Op::Group { sources } => assert_eq!(sources.len(), 4, "Should have 4 STORE sources"),
        other => panic!("Expected GROUP or STORE, got {:?}", other),
    }
}
/// Loads a Float32 vec8 through an index cast to a vec8 pointer and checks the
/// devectorized shape: a CAT of two vec4 LOADs, an eight-element VECTORIZE, or
/// a LOAD that still has vcount 8.
#[test]
fn test_split_load_vec8_to_vec4() {
    let buf = create_buffer(128);
    let base_ptr = create_index(buf.clone(), 0);
    let wide_ptr_ty = DType::Float32.vec(8).ptr(Some(8), AddrSpace::Global);
    let wide_ptr = base_ptr.cast(wide_ptr_ty);
    let wide_ty = DType::Float32.vec(8);
    let wide_load = UOp::load().buffer(buf.clone()).index(wide_ptr).dtype(wide_ty).call();

    let rewritten = apply_devectorize(&wide_load);

    match rewritten.op() {
        Op::Load { .. } => {
            assert_eq!(rewritten.dtype().vcount(), 8);
        }
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 8);
        }
        Op::Cat { sources } => {
            assert_eq!(sources.len(), 2, "vec8 should split into 2 chunks");
            for chunk in sources.iter() {
                assert_is_load(chunk);
                assert_eq!(chunk.dtype().vcount(), 4, "Each chunk should be vec4");
            }
        }
        other => panic!("Expected CAT, VECTORIZE or LOAD, got {:?}", other),
    }
}
/// Loads a Float32 vec6 through an index cast to a vec6 pointer. A CAT result
/// must have two chunks whose lane counts are `[4, 2]` or otherwise add up to
/// six; a six-element VECTORIZE or any LOAD is accepted as well.
#[test]
fn test_split_load_vec6_mixed() {
    let buf = create_buffer(128);
    let base_ptr = create_index(buf.clone(), 0);
    let odd_ptr_ty = DType::Float32.vec(6).ptr(Some(6), AddrSpace::Global);
    let odd_ptr = base_ptr.cast(odd_ptr_ty);
    let odd_ty = DType::Float32.vec(6);
    let odd_load = UOp::load().buffer(buf.clone()).index(odd_ptr).dtype(odd_ty).call();

    let rewritten = apply_devectorize(&odd_load);

    match rewritten.op() {
        // An unsplit LOAD is tolerated without further checks.
        Op::Load { .. } => {}
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 6);
        }
        Op::Cat { sources } => {
            assert_eq!(sources.len(), 2, "vec6 should split into 2 chunks");
            let lane_counts: Vec<usize> = sources.iter().map(|chunk| chunk.dtype().vcount()).collect();
            let lane_total: usize = lane_counts.iter().sum();
            assert!(lane_counts == vec![4, 2] || lane_total == 6, "Chunks should sum to 6");
        }
        other => panic!("Expected CAT, VECTORIZE or LOAD, got {:?}", other),
    }
}
/// Stores an eight-lane float vector through an index cast to a vec8 pointer.
/// The devectorized result must be a GROUP of two STOREs, or a single STORE.
#[test]
fn test_split_store_vec8() {
    let buf = create_buffer(128);
    let payload = create_vector_float_iota(8);
    let base_ptr = create_index(buf.clone(), 0);
    let wide_ptr_ty = DType::Float32.vec(8).ptr(Some(8), AddrSpace::Global);
    let wide_ptr = base_ptr.cast(wide_ptr_ty);
    let wide_store = wide_ptr.store(payload);

    let rewritten = apply_devectorize(&wide_store);

    match rewritten.op() {
        // A lone STORE is tolerated without further checks.
        Op::Store { .. } => {}
        Op::Group { sources } => {
            assert_eq!(sources.len(), 2, "vec8 store should split into 2 chunks");
            for chunk in sources.iter() {
                assert_is_store(chunk);
            }
        }
        other => panic!("Expected GROUP or STORE, got {:?}", other),
    }
}
/// Builds a vec8 STORE that carries one loop RANGE and checks that the range
/// list survives devectorization, whether the result is a single STORE or a
/// GROUP of split STOREs.
#[test]
fn test_split_preserves_ranges() {
    use morok_ir::AxisId;
    use smallvec::smallvec;

    let buf = create_buffer(128);
    let payload = create_vector_float_iota(8);
    let base_ptr = create_index(buf.clone(), 0);
    let wide_ptr_ty = DType::Float32.vec(8).ptr(Some(8), AddrSpace::Global);
    let wide_ptr = base_ptr.cast(wide_ptr_ty);

    // A single loop range ending at the constant 10, with no dependencies.
    let loop_range = UOp::new(
        Op::Range {
            end: UOp::const_(DType::Index, ConstValue::Int(10)),
            axis_id: AxisId::Renumbered(0),
            axis_type: morok_ir::AxisType::Loop,
            deps: smallvec::SmallVec::new(),
        },
        DType::Index,
    );

    let ranged_store = wide_ptr.store_with_ranges(payload, smallvec![loop_range.clone()]);
    let rewritten = apply_devectorize(&ranged_store);

    match rewritten.op() {
        Op::Store { ranges, .. } => {
            assert_eq!(ranges.len(), 1, "Ranges should be preserved");
        }
        Op::Group { sources } => {
            assert!(!sources.is_empty(), "Should have split stores");
            // NOTE(review): children that are not STOREs are skipped silently
            // here — confirm whether they should be rejected instead.
            for child in sources.iter() {
                if let Op::Store { ranges, .. } = child.op() {
                    assert_eq!(ranges.len(), 1, "Each split store should preserve ranges");
                }
            }
        }
        other => panic!("Expected GROUP or STORE, got {:?}", other),
    }
}
/// Runs devectorization on a LOAD fed by a four-lane vector index and checks
/// that no PTRCAT is left behind and at least one LOAD remains.
#[test]
fn test_load_vector_index_full_pipeline() {
    let buf = create_buffer(64);
    let vector_ptr = create_vector_index_iota(buf.clone(), 4);
    let vector_load = UOp::load().buffer(buf.clone()).index(vector_ptr).call();

    let rewritten = apply_devectorize(&vector_load);

    let leftover_ptrcats = count_ptrcats(&rewritten);
    assert_eq!(leftover_ptrcats, 0, "No PTRCAT should remain");
    assert!(count_loads(&rewritten) >= 1, "Should have at least one LOAD");
}
/// Runs devectorization on a STORE of a four-lane float vector through a
/// four-lane vector index and checks that no PTRCAT is left behind and at
/// least one STORE remains.
#[test]
fn test_store_vector_index_full_pipeline() {
    let buf = create_buffer(64);
    let payload = create_vector_float_iota(4);
    let vector_ptr = create_vector_index_iota(buf.clone(), 4);
    let vector_store = vector_ptr.store(payload);

    let rewritten = apply_devectorize(&vector_store);

    let leftover_ptrcats = count_ptrcats(&rewritten);
    assert_eq!(leftover_ptrcats, 0, "No PTRCAT should remain");
    assert!(count_stores(&rewritten) >= 1, "Should have at least one store");
}
/// Loads a Float32 vec8 from an index whose offset is the constant 8. The
/// result must keep a total vcount of 8; a CAT may use at most four chunks,
/// and a VECTORIZE must hold eight elements with at least one LOAD beneath it.
#[test]
fn test_split_load_divisibility() {
    let buf = create_buffer(128);
    let base_ptr = UOp::index()
        .buffer(buf.clone())
        .indices(vec![UOp::const_(DType::Index, ConstValue::Int(8))])
        .call()
        .unwrap();
    let wide_ptr_ty = DType::Float32.vec(8).ptr(Some(8), AddrSpace::Global);
    let wide_ptr = base_ptr.cast(wide_ptr_ty);
    let wide_ty = DType::Float32.vec(8);
    let wide_load = UOp::load().buffer(buf.clone()).index(wide_ptr).dtype(wide_ty).call();

    let rewritten = apply_devectorize(&wide_load);

    // Whatever the shape, the overall lane count must be unchanged.
    assert_eq!(rewritten.dtype().vcount(), 8, "Total vcount should be 8");

    match rewritten.op() {
        Op::Load { .. } => {
            assert_eq!(rewritten.dtype().vcount(), 8, "Single load should have vcount 8");
        }
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 8, "Vectorize should have 8 elements");
            assert!(count_loads(&rewritten) >= 1, "Should have at least one LOAD");
        }
        Op::Cat { sources } => {
            assert!(sources.len() <= 4, "Should use fewer, larger chunks");
            let lane_total = sources.iter().map(|chunk| chunk.dtype().vcount()).sum::<usize>();
            assert_eq!(lane_total, 8, "Total vcount should be 8");
        }
        other => panic!("Expected CAT, LOAD, or VECTORIZE, got {:?}", other),
    }
}
/// Loads a Float32 vec8 from an index whose offset is the constant 3. Any of
/// CAT, VECTORIZE or LOAD is accepted as long as it accounts for eight lanes.
#[test]
fn test_split_load_not_divisible() {
    let buf = create_buffer(128);
    let base_ptr = UOp::index()
        .buffer(buf.clone())
        .indices(vec![UOp::const_(DType::Index, ConstValue::Int(3))])
        .call()
        .unwrap();
    let wide_ptr_ty = DType::Float32.vec(8).ptr(Some(8), AddrSpace::Global);
    let wide_ptr = base_ptr.cast(wide_ptr_ty);
    let wide_ty = DType::Float32.vec(8);
    let wide_load = UOp::load().buffer(buf.clone()).index(wide_ptr).dtype(wide_ty).call();

    let rewritten = apply_devectorize(&wide_load);

    match rewritten.op() {
        Op::Load { .. } => {
            assert_eq!(rewritten.dtype().vcount(), 8, "Single load should have vcount 8");
        }
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 8, "Vectorize should have 8 elements");
        }
        Op::Cat { sources } => {
            let lane_total = sources.iter().map(|chunk| chunk.dtype().vcount()).sum::<usize>();
            assert_eq!(lane_total, 8, "Total vcount should be 8");
        }
        other => panic!("Expected CAT, VECTORIZE, or LOAD, got {:?}", other),
    }
}
/// Devectorizes a LOAD whose INDEX node carries a constant-true gate and
/// checks only that the result is a LOAD, CAT or VECTORIZE.
#[test]
fn test_gated_index_load() {
    let buf = create_buffer(64);
    let always_on = create_bool_const(true);
    let zero_offset = UOp::const_(DType::Index, ConstValue::Int(0));

    // The INDEX node is built by hand so that the gate operand can be set.
    let gated_ptr = UOp::new(
        Op::Index {
            buffer: buf.clone(),
            indices: smallvec::smallvec![zero_offset],
            gate: Some(always_on.clone()),
        },
        DType::Float32,
    );

    let gated_load = UOp::load().buffer(buf.clone()).index(gated_ptr).call();
    let rewritten = apply_devectorize(&gated_load);

    let is_load_shaped = matches!(rewritten.op(), Op::Load { .. } | Op::Cat { .. } | Op::Vectorize { .. });
    assert!(is_load_shaped, "Should produce valid load structure");
}