use std::sync::Arc;
use morok_dtype::{DType, ScalarDType};
use morok_ir::types::ConstValue;
use morok_ir::{AxisId, AxisType, BinaryOp, Op, UOp};
use super::helpers::*;
/// A 4-wide LOAD addressed through `create_vector_index_iota` must keep its
/// total lane count of 4 after devectorization and still contain a LOAD.
#[test]
fn test_devectorize_contiguous_load() {
    let buf = create_buffer(64);
    let iota_idx = create_vector_index_iota(buf.clone(), 4);
    let vec_load = UOp::load().buffer(buf.clone()).index(iota_idx).call();

    let devectorized = apply_devectorize(&vec_load);

    // The overall lane count must survive the rewrite.
    assert_vcount(&devectorized, 4);
    assert!(count_loads(&devectorized) >= 1, "Should have at least one LOAD");
}
/// A 4-wide STORE of `create_vector_float_iota(4)` through an iota index must
/// still contain at least one STORE after devectorization.
#[test]
fn test_devectorize_contiguous_store() {
    let buf = create_buffer(64);
    let vals = create_vector_float_iota(4);
    let iota_idx = create_vector_index_iota(buf.clone(), 4);
    let store_op = iota_idx.store(vals);

    let devectorized = apply_devectorize(&store_op);
    let stores_after = count_stores(&devectorized);
    assert!(stores_after >= 1, "Should have at least one STORE");
}
/// A 4-lane LOAD whose index comes from `create_vector_index_scaled(_, 4, 2)`
/// (a strided pattern, per the test name) must still produce LOADs.
#[test]
fn test_devectorize_strided_load() {
    let buf = create_buffer(128);
    let strided_idx = create_vector_index_scaled(buf.clone(), 4, 2);
    let strided_load = UOp::load().buffer(buf.clone()).index(strided_idx).call();

    let devectorized = apply_devectorize(&strided_load);
    let loads = count_loads(&devectorized);
    assert!(loads >= 1, "Should have LOADs for strided access");
}
/// Runs an 8-wide LOAD through `devectorize` and then the `pm_render` rewrite,
/// calling the passes directly instead of the `apply_devectorize` helper.
/// The result must still span 8 lanes and contain a LOAD.
#[test]
fn test_devectorize_matmul_pattern() {
    use crate::devectorize::{devectorize, pm_render};
    use crate::rewrite::graph_rewrite;

    let buf = create_buffer(256);
    let wide_load = {
        let idx = create_vector_index_iota(buf.clone(), 8);
        UOp::load().buffer(buf.clone()).index(idx).call()
    };

    // Same pass order as before: devectorize first, then render patterns.
    let devectorized = devectorize(&wide_load);
    let rendered = graph_rewrite(pm_render(), devectorized, &mut ());

    assert_eq!(rendered.dtype().vcount(), 8, "Total vcount should be 8");
    assert!(count_loads(&rendered) >= 1, "Should have at least one LOAD");
}
/// Read-modify-write on one buffer: a 4-wide LOAD is added to a 4-wide float
/// vector, and the sum is STOREd back through a second iota index into the same
/// buffer. After devectorization both the LOAD and the STORE must remain.
#[test]
fn test_devectorize_reduction_accumulator() {
    let buffer = create_buffer(64);
    let acc_index = create_vector_index_iota(buffer.clone(), 4);
    let acc_load = UOp::load().buffer(buffer.clone()).index(acc_index).call();
    let values = create_vector_float_iota(4);
    let add = UOp::new(Op::Binary(BinaryOp::Add, acc_load, values), DType::Float32.vec(4));
    let store_index = create_vector_index_iota(buffer.clone(), 4);
    let store = store_index.store(add);
    let result = apply_devectorize(&store);
    let load_count = count_loads(&result);
    let store_count = count_stores(&result);
    // Assert separately so a failure says which op disappeared and what was found.
    assert!(load_count >= 1, "Should have at least one LOAD, found {load_count}");
    assert!(store_count >= 1, "Should have at least one STORE, found {store_count}");
}
/// C = A + B over three separate 64-element buffers, using 4-wide iota indices.
/// Devectorization must keep LOADs from both inputs and the STORE to the output.
#[test]
fn test_devectorize_multiple_buffers() {
    let buffer_a = create_buffer(64);
    let buffer_b = create_buffer(64);
    let buffer_c = create_buffer(64);

    // Build the two input loads the same way, A first and then B.
    let [lhs, rhs] = [&buffer_a, &buffer_b].map(|buf| {
        let idx = create_vector_index_iota(buf.clone(), 4);
        UOp::load().buffer(buf.clone()).index(idx).call()
    });
    let sum = UOp::new(Op::Binary(BinaryOp::Add, lhs, rhs), DType::Float32.vec(4));

    let out_idx = create_vector_index_iota(buffer_c.clone(), 4);
    let devectorized = apply_devectorize(&out_idx.store(sum));

    let (loads, stores) = (count_loads(&devectorized), count_stores(&devectorized));
    assert!(loads >= 2, "Should have LOADs from both A and B");
    assert!(stores >= 1, "Should have STORE to C");
}
/// A 4-wide LOAD plus a float vector, STOREd back to the same buffer, must
/// still contain a STORE after devectorization.
///
/// NOTE(review): the name refers to the pre-expand stage, but the graph here is
/// built by hand. Confirm that it matches what pre-expand actually emits.
#[test]
fn test_devectorize_after_pre_expand() {
    let buffer = create_buffer(64);
    let index = create_vector_index_iota(buffer.clone(), 4);
    let load = UOp::load().buffer(buffer.clone()).index(index).call();
    let value = create_vector_float_iota(4);
    let add = UOp::new(Op::Binary(BinaryOp::Add, load, value), DType::Float32.vec(4));
    let store_index = create_vector_index_iota(buffer.clone(), 4);
    let store = store_index.store(add);
    let result = apply_devectorize(&store);
    let store_count = count_stores(&result);
    // Report the observed count so a failure is diagnosable without a debugger.
    assert!(store_count >= 1, "Should have at least one STORE, found {store_count}");
}
/// An 8-wide STORE of `create_vector_float_iota(8)` through an iota index must
/// still contain at least one STORE after devectorization.
#[test]
fn test_devectorize_with_output_upcast() {
    let buffer = create_buffer(256);
    let index = create_vector_index_iota(buffer.clone(), 8);
    let value = create_vector_float_iota(8);
    let store = index.store(value);
    let result = apply_devectorize(&store);
    let store_count = count_stores(&result);
    // Report the observed count so a failure is diagnosable without a debugger.
    assert!(store_count >= 1, "Should have at least one STORE, found {store_count}");
}
/// A 4-wide LOAD whose lane indices are `range * 4 + k` (k = 0..4), with
/// `range` being a 64-iteration loop RANGE. Devectorization must keep the
/// total lane count at 4 and produce at least one LOAD.
#[test]
fn test_devectorize_loop_index() {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Source of PARAM ids for this test. It starts at 20000, presumably to avoid
    // colliding with ids used elsewhere. TODO confirm.
    static NEXT_PARAM_ID: AtomicUsize = AtomicUsize::new(20000);

    let buffer = create_buffer(256);
    let param_id = NEXT_PARAM_ID.fetch_add(1, Ordering::Relaxed);
    let param = UOp::param(param_id, 256, buffer.dtype(), None);
    let param_x4 = param.broadcast(4);

    let loop_var = UOp::new(
        Op::Range {
            end: UOp::const_(DType::Index, ConstValue::Int(64)),
            axis_id: AxisId::Renumbered(0),
            axis_type: AxisType::Loop,
            deps: smallvec::SmallVec::new(),
        },
        DType::Index,
    );
    let four = UOp::const_(DType::Index, ConstValue::Int(4));
    let lane0 = UOp::new(Op::Binary(BinaryOp::Mul, loop_var, four), DType::Index);

    // Lane 0 is `range * 4`; lanes 1..4 add the lane number to it.
    let lanes: smallvec::SmallVec<[Arc<UOp>; 4]> = std::iter::once(lane0.clone())
        .chain((1..4).map(|k| {
            UOp::new(
                Op::Binary(BinaryOp::Add, lane0.clone(), UOp::const_(DType::Index, ConstValue::Int(k))),
                DType::Index,
            )
        }))
        .collect();
    let lane_vec = UOp::vectorize(lanes);

    let index = UOp::new(
        Op::Index { buffer: param_x4, indices: smallvec::smallvec![lane_vec], gate: None },
        DType::Float32,
    );
    let load = UOp::load().buffer(param).index(index).call();

    let result = apply_devectorize(&load);
    assert_eq!(result.dtype().vcount(), 4, "Total vcount should be 4");
    assert!(count_loads(&result) >= 1, "Should have at least one LOAD");
}
/// Two independent 4-wide STOREs to different buffers, joined under one SINK,
/// must both survive devectorization (at least two STOREs in the result).
#[test]
fn test_devectorize_sink_multiple_stores() {
    let buf_a = create_buffer(64);
    let buf_b = create_buffer(64);

    let store_a = {
        let idx = create_vector_index_iota(buf_a.clone(), 4);
        let vals = create_vector_float_iota(4);
        idx.store(vals)
    };
    let store_b = {
        let idx = create_vector_index_iota(buf_b.clone(), 4);
        let vals = create_vector_float_values(vec![10.0, 11.0, 12.0, 13.0]);
        idx.store(vals)
    };

    let root = UOp::sink(vec![store_a, store_b]);
    let devectorized = apply_devectorize(&root);
    assert!(count_stores(&devectorized) >= 2, "Should have stores from both operations");
}
/// A 4-wide LOAD from a Float16 buffer must keep Float16 as its scalar base type
/// after devectorization.
#[test]
fn test_devectorize_float16() {
    let f16_buf = create_buffer_typed(64, ScalarDType::Float16);
    let idx = create_vector_index_iota(f16_buf.clone(), 4);
    let f16_load = UOp::load().buffer(f16_buf.clone()).index(idx).call();

    let devectorized = apply_devectorize(&f16_load);
    let elem = devectorized.dtype().base();
    assert_eq!(elem, ScalarDType::Float16);
}
/// A 4-wide LOAD from an Int32 buffer must keep Int32 as its scalar base type
/// after devectorization.
#[test]
fn test_devectorize_int32() {
    let i32_buf = create_buffer_typed(64, ScalarDType::Int32);
    let idx = create_vector_index_iota(i32_buf.clone(), 4);
    let i32_load = UOp::load().buffer(i32_buf.clone()).index(idx).call();

    let devectorized = apply_devectorize(&i32_load);
    let elem = devectorized.dtype().base();
    assert_eq!(elem, ScalarDType::Int32);
}
/// A scalar LOAD at index 0 of a bool buffer must come out of devectorization
/// with either Bool or UInt8 as its scalar base type (both are accepted here).
#[test]
fn test_devectorize_bool_pipeline() {
    let buffer = create_bool_buffer(64);
    // One statement per line: the original packed two `let`s onto one line.
    let index = create_index(buffer.clone(), 0);
    let load = UOp::load().buffer(buffer.clone()).index(index).call();
    let result = apply_devectorize(&load);
    let base = result.dtype().base();
    // Include the observed dtype so a failure shows what the pass actually produced.
    assert!(
        base == ScalarDType::Bool || base == ScalarDType::UInt8,
        "Bool should be handled correctly, got {base:?}"
    );
}