use std::sync::Arc;
use morok_dtype::{DType, ScalarDType};
use morok_ir::types::ConstValue;
use morok_ir::{AxisId, AxisType, BinaryOp, Op, UOp};
use smallvec::smallvec;
use super::helpers::*;
/// Asserts that the graph rooted at `uop` contains no PTRCAT nodes.
///
/// The tests in this file call it on the output of `apply_devectorize`.
/// It is marked `#[track_caller]` so that a failure is reported at the
/// calling test's line rather than inside this shared helper.
///
/// # Panics
/// Panics if `count_ptrcats(uop)` is non-zero.
#[track_caller]
fn assert_no_ptrcat(uop: &Arc<UOp>) {
    let ptrcat_count = count_ptrcats(uop);
    assert_eq!(ptrcat_count, 0, "No PTRCAT nodes should remain after full devectorize, found {}", ptrcat_count);
}
/// Devectorizes a LOAD through a 4-wide iota vector index over a 64-element
/// buffer. Per the helper's name, the index is presumably consecutive offsets.
/// No PTRCAT may remain, and at least one LOAD must be emitted.
#[test]
fn test_expand_contiguous_vec4() {
    let buffer = create_buffer(64);
    let index = create_vector_index_iota(buffer.clone(), 4);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// Devectorizes a LOAD through an 8-wide iota vector index over a 128-element
/// buffer. No PTRCAT may remain, and at least one LOAD must be emitted.
#[test]
fn test_expand_contiguous_vec8() {
    let buffer = create_buffer(128);
    let index = create_vector_index_iota(buffer.clone(), 8);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// Devectorizes a LOAD through a 4-wide vector index built by
/// `create_vector_index_offset(.., 4, 10)`. Per the helper's name, this is
/// presumably a contiguous run starting at 10. No PTRCAT may remain, and at
/// least one LOAD must be emitted.
#[test]
fn test_expand_contiguous_with_offset() {
    let buffer = create_buffer(64);
    let index = create_vector_index_offset(buffer.clone(), 4, 10);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// After devectorizing a 4-wide iota LOAD, the rewritten graph must still
/// reference at least one `Op::Param` whose `device` is `None` (the codegen
/// buffer parameter). No PTRCAT may remain.
#[test]
fn test_expand_contiguous_preserves_buffer() {
    let buf = create_buffer(64);
    let vec_index = create_vector_index_iota(buf.clone(), 4);
    let devectorized = apply_devectorize(&UOp::load().buffer(buf).index(vec_index).call());

    assert_no_ptrcat(&devectorized);

    let param_count = count_ops(&devectorized, |node| matches!(node.op(), Op::Param { device: None, .. }));
    assert!(param_count > 0, "Codegen PARAM reference should be present");
}
/// A 4-wide index built by `create_vector_index_scaled(.., 4, 2)` is
/// presumably stride-2 lanes, per the helper's name. Devectorizing it must
/// yield exactly four scalar LOADs and no PTRCAT.
#[test]
fn test_expand_strided_access() {
    let buf = create_buffer(64);
    let strided = create_vector_index_scaled(buf.clone(), 4, 2);
    let expanded = apply_devectorize(&UOp::load().buffer(buf).index(strided).call());

    assert_no_ptrcat(&expanded);
    assert_eq!(count_loads(&expanded), 4, "Strided access should produce 4 scalar LOADs");
}
/// The lane indices `[0, 1, 5, 6]` form two adjacent pairs. The test accepts
/// any outcome from a single LOAD up to one LOAD per lane, as long as no
/// PTRCAT survives.
#[test]
fn test_expand_mixed_groups() {
    let buf = create_buffer(64);
    let lanes = create_vector_index_values(buf.clone(), vec![0, 1, 5, 6]);
    let expanded = apply_devectorize(&UOp::load().buffer(buf).index(lanes).call());
    assert_no_ptrcat(&expanded);

    let loads = count_loads(&expanded);
    let allowed = 1..=4;
    assert!(allowed.contains(&loads), "Should have between 1 and 4 LOADs, got {}", loads);
}
/// Devectorizes a LOAD whose lane indices run backwards (`[3, 2, 1, 0]`).
/// No PTRCAT may remain, and at least one LOAD must be emitted.
#[test]
fn test_expand_reversed_indices() {
    let buffer = create_buffer(64);
    let index = create_vector_index_values(buffer.clone(), vec![3, 2, 1, 0]);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// The lane indices `[0, 5, 3, 7]` have no two consecutive values in lane
/// order. Devectorizing must produce exactly four scalar LOADs and no PTRCAT.
#[test]
fn test_expand_scattered_indices() {
    let buf = create_buffer(64);
    let scattered = create_vector_index_values(buf.clone(), vec![0, 5, 3, 7]);
    let expanded = apply_devectorize(&UOp::load().buffer(buf).index(scattered).call());

    assert_no_ptrcat(&expanded);
    assert_eq!(count_loads(&expanded), 4, "Scattered access should produce 4 scalar LOADs");
}
/// A LOAD through a plain scalar index (`create_index(.., 5)`) must come out
/// of devectorize as a LOAD with a vector count of 1, and with no PTRCAT.
#[test]
fn test_expand_scalar_index_no_change() {
    let buf = create_buffer(64);
    let scalar_idx = create_index(buf.clone(), 5);
    let out = apply_devectorize(&UOp::load().buffer(buf).index(scalar_idx).call());

    assert_no_ptrcat(&out);
    assert_is_load(&out);
    assert_eq!(out.dtype().vcount(), 1, "Should remain scalar");
}
/// Devectorizes a LOAD through a 4-wide vector index that carries a constant
/// `true` gate. No PTRCAT may remain, and at least one LOAD must be emitted.
#[test]
fn test_expand_gated_index() {
    let buffer = create_buffer(64);
    let gate = create_bool_const(true);
    let index = create_vector_index_gated(buffer.clone(), 4, gate);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// Builds an INDEX with two scalar index operands (`0` and `1`) and no gate.
/// The test only requires that devectorizing the LOAD still yields a LOAD;
/// it makes no claim about PTRCATs.
#[test]
fn test_expand_multi_index_unsupported() {
    let buf = create_buffer(64);
    let [first, second] = [0, 1].map(|i| UOp::const_(DType::Index, ConstValue::Int(i)));
    let multi_index = UOp::new(
        Op::Index { buffer: buf.clone(), indices: smallvec![first, second], gate: None },
        DType::Float32,
    );
    let out = apply_devectorize(&UOp::load().buffer(buf).index(multi_index).call());
    assert_is_load(&out);
}
/// Builds a 4-wide INDEX whose lanes are `r*4 + k` for `k` in `0..4`. Here `r`
/// is a Loop RANGE ending at 64, and the INDEX addresses a PARAM broadcast to
/// 4 lanes. Devectorizing the LOAD must leave no PTRCAT and emit at least one
/// LOAD.
#[test]
fn test_expand_range_based_index() {
    let buffer = create_buffer(256);
    // Function-local id source. The distinct start value (1000) presumably
    // keeps these PARAM ids apart from ids used elsewhere; confirm if ids
    // must be globally unique.
    static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(1000);
    let def_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let define = UOp::param(def_id, 256, buffer.dtype(), None);
    let buf_vec = define.broadcast(4);
    let range = UOp::new(
        Op::Range {
            end: UOp::const_(DType::Index, ConstValue::Int(64)),
            axis_id: AxisId::Renumbered(0),
            axis_type: AxisType::Loop,
            deps: smallvec::SmallVec::new(),
        },
        DType::Index,
    );
    let four = UOp::const_(DType::Index, ConstValue::Int(4));
    let base = UOp::new(Op::Binary(BinaryOp::Mul, range, four), DType::Index);
    // Lane 0 is `base` itself; lanes 1..4 are `base + i`.
    let indices: smallvec::SmallVec<[Arc<UOp>; 4]> = (0..4)
        .map(|i| {
            if i == 0 {
                base.clone()
            } else {
                UOp::new(
                    Op::Binary(BinaryOp::Add, base.clone(), UOp::const_(DType::Index, ConstValue::Int(i))),
                    DType::Index,
                )
            }
        })
        .collect();
    let vec_idx = UOp::vectorize(indices);
    let index = UOp::new(Op::Index { buffer: buf_vec, indices: smallvec![vec_idx], gate: None }, DType::Float32);
    // NOTE(review): the LOAD's `buffer` is `buffer`, while the INDEX addresses the
    // broadcast `define`. Confirm this mismatch is intentional.
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// The lanes `r, r+1, r+10, r+11` share one symbolic root `r`, a Loop RANGE
/// ending at 64, but form two separate adjacent pairs. Devectorizing must
/// leave no PTRCAT and emit at least two LOADs.
#[test]
fn test_expand_symbolic_root_grouping() {
    let buffer = create_buffer(256);
    // Function-local id source with its own start value (2000), presumably so
    // that PARAM ids stay distinct from other tests.
    static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(2000);
    let param_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let param = UOp::param(param_id, 256, buffer.dtype(), None);
    let broadcast_buf = param.broadcast(4);

    let root = UOp::new(
        Op::Range {
            end: UOp::const_(DType::Index, ConstValue::Int(64)),
            axis_id: AxisId::Renumbered(0),
            axis_type: AxisType::Loop,
            deps: smallvec::SmallVec::new(),
        },
        DType::Index,
    );

    // Offset 0 reuses the root node directly; every other offset is `root + k`.
    let lanes: smallvec::SmallVec<[Arc<UOp>; 4]> = [0i64, 1, 10, 11]
        .iter()
        .copied()
        .map(|offset| match offset {
            0 => root.clone(),
            k => UOp::new(
                Op::Binary(BinaryOp::Add, root.clone(), UOp::const_(DType::Index, ConstValue::Int(k))),
                DType::Index,
            ),
        })
        .collect();

    let vector_index = UOp::new(
        Op::Index { buffer: broadcast_buf, indices: smallvec![UOp::vectorize(lanes)], gate: None },
        DType::Float32,
    );
    // NOTE(review): the LOAD uses `buffer`, while the INDEX addresses the broadcast PARAM.
    let expanded = apply_devectorize(&UOp::load().buffer(buffer).index(vector_index).call());

    assert_no_ptrcat(&expanded);
    let loads = count_loads(&expanded);
    assert!(loads >= 2, "Should have at least 2 LOADs for 2 groups, got {}", loads);
}
/// The lanes are `a, a+1, b, b+1`, where `a` and `b` are two distinct Loop
/// RANGEs (axes 0 and 1, each ending at 64). Because the lanes have
/// different roots, devectorizing must emit at least two LOADs and leave no
/// PTRCAT.
#[test]
fn test_expand_different_roots_separate() {
    let buffer = create_buffer(256);
    // Function-local id source with its own start value (3000), presumably so
    // that PARAM ids stay distinct from other tests.
    static COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(3000);
    let param_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let param = UOp::param(param_id, 256, buffer.dtype(), None);
    let broadcast_buf = param.broadcast(4);

    // One Loop RANGE ending at 64 on the given axis.
    let loop_range = |axis_id: AxisId| {
        UOp::new(
            Op::Range {
                end: UOp::const_(DType::Index, ConstValue::Int(64)),
                axis_id,
                axis_type: AxisType::Loop,
                deps: smallvec::SmallVec::new(),
            },
            DType::Index,
        )
    };
    let range_a = loop_range(AxisId::Renumbered(0));
    let range_b = loop_range(AxisId::Renumbered(1));

    let plus_one = |root: &Arc<UOp>| {
        UOp::new(
            Op::Binary(BinaryOp::Add, root.clone(), UOp::const_(DType::Index, ConstValue::Int(1))),
            DType::Index,
        )
    };
    let lanes: smallvec::SmallVec<[Arc<UOp>; 4]> =
        smallvec![range_a.clone(), plus_one(&range_a), range_b.clone(), plus_one(&range_b)];

    let vector_index = UOp::new(
        Op::Index { buffer: broadcast_buf, indices: smallvec![UOp::vectorize(lanes)], gate: None },
        DType::Float32,
    );
    // NOTE(review): the LOAD uses `buffer`, while the INDEX addresses the broadcast PARAM.
    let expanded = apply_devectorize(&UOp::load().buffer(buffer).index(vector_index).call());

    assert_no_ptrcat(&expanded);
    let loads = count_loads(&expanded);
    assert!(loads >= 2, "Different roots should produce at least 2 LOADs, got {}", loads);
}
/// Devectorizes a LOAD through a 4-wide iota index over an `Int32` buffer.
/// No PTRCAT may remain, and at least one LOAD must be emitted.
#[test]
fn test_expand_int32_buffer() {
    let buffer = create_buffer_typed(64, ScalarDType::Int32);
    let index = create_vector_index_iota(buffer.clone(), 4);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// Devectorizes a LOAD through a 4-wide iota index over a `Float16` buffer.
/// No PTRCAT may remain, and at least one LOAD must be emitted.
#[test]
fn test_expand_half_buffer() {
    let buffer = create_buffer_typed(64, ScalarDType::Float16);
    let index = create_vector_index_iota(buffer.clone(), 4);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// Devectorizes a LOAD whose four lanes all read index 0 (`[0, 0, 0, 0]`).
/// No PTRCAT may remain, and at least one LOAD must be emitted.
#[test]
fn test_expand_pure_broadcast() {
    let buffer = create_buffer(64);
    let index = create_vector_index_values(buffer.clone(), vec![0, 0, 0, 0]);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// Devectorizes a LOAD whose lanes repeat the pair `[0, 1]` twice
/// (`[0, 1, 0, 1]`). No PTRCAT may remain, and at least one LOAD must be
/// emitted.
#[test]
fn test_expand_partial_broadcast() {
    let buffer = create_buffer(64);
    let index = create_vector_index_values(buffer.clone(), vec![0, 1, 0, 1]);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}
/// Devectorizes a LOAD whose lanes mix a repeated index with distinct ones
/// (`[0, 1, 0, 2]`). No PTRCAT may remain, and at least one LOAD must be
/// emitted.
#[test]
fn test_expand_mixed_broadcast() {
    let buffer = create_buffer(64);
    let index = create_vector_index_values(buffer.clone(), vec![0, 1, 0, 2]);
    let load = UOp::load().buffer(buffer).index(index).call();
    let result = apply_devectorize(&load);
    assert_no_ptrcat(&result);
    let load_count = count_loads(&result);
    // Report the observed count on failure, matching the other tests in this file.
    assert!(load_count >= 1, "Should have at least one LOAD, got {}", load_count);
}