use std::sync::Arc;
use morok_dtype::ScalarDType;
use morok_ir::{Op, UOp};
use smallvec::smallvec;
use super::helpers::*;
/// Moving a single-lane GEP (`[2]`) past a LOAD must yield a scalar
/// (vcount 1) value whose element type is still the buffer's Float32.
#[test]
fn test_gep_after_load_single_index() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let lane = create_vector_index_iota(buf.clone(), 4).gep(vec![2]);
    let moved = apply_gep_movement(&UOp::load().buffer(buf).index(lane).call());

    assert_eq!(moved.dtype().vcount(), 1, "Single-index GEP extracts one element");
    assert_eq!(moved.dtype().base(), ScalarDType::Float32, "Base dtype preserved");
}
/// A three-lane GEP `[0, 2, 1]` over a 4-wide index feeding a LOAD must,
/// after GEP movement, produce a 3-wide result.
#[test]
fn test_gep_after_load_multi_index() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let lanes = create_vector_index_iota(buf.clone(), 4).gep(vec![0, 2, 1]);
    let loaded = UOp::load().buffer(buf).index(lanes).call();
    let moved = apply_gep_movement(&loaded);

    assert_eq!(moved.dtype().vcount(), 3, "Multi-index GEP extracts 3 elements");
}
/// A GEP `[0, 2]` over a PTRCAT of two 2-wide index vectors is used as a LOAD
/// index; after GEP movement no LOAD may be indexed directly by a PTRCAT.
#[test]
fn test_gep_after_load_with_ptrcat() {
    let (first, second) = (
        create_buffer_typed(32, ScalarDType::Float32),
        create_buffer_typed(32, ScalarDType::Float32),
    );
    let cat = UOp::ptrcat()
        .sources(vec![
            create_vector_index_iota(first.clone(), 2),
            create_vector_index_iota(second, 2),
        ])
        .call();
    let load = UOp::load().buffer(first).index(cat.gep(vec![0, 2])).call();
    let moved = apply_gep_movement(&load);

    // Count LOADs whose index operand is itself a PTRCAT.
    let loads_indexed_by_ptrcat = count_ops(&moved, |u| match u.op() {
        Op::Load { index, .. } => matches!(index.op(), Op::PtrCat { .. }),
        _ => false,
    });
    assert_eq!(loads_indexed_by_ptrcat, 0, "PTRCAT should be distributed through LOAD");
}
/// Selecting lanes `[0, 2]` of a 4-wide index must give a 2-wide Float32
/// value once the GEP has been moved after the LOAD.
#[test]
fn test_gep_after_load_dtype_calculation() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let even_lanes = create_vector_index_iota(buf.clone(), 4).gep(vec![0, 2]);
    let load = UOp::load().buffer(buf).index(even_lanes).call();
    let moved = apply_gep_movement(&load);

    assert_eq!(moved.dtype().vcount(), 2);
    assert_eq!(moved.dtype().base(), ScalarDType::Float32);
}
/// Loads through a two-lane GEP `[1, 3]`. After GEP movement, the graph must
/// either contain a Float32 LOAD (first one reached depth-first) or the
/// result itself must have Float32 element type.
///
/// NOTE(review): despite its name, this test only checks element dtype; it
/// does not verify that the LOAD still targets the original buffer.
#[test]
fn test_gep_after_load_preserves_buffer() {
    /// Depth-first search that stops at the first LOAD on each path and
    /// reports whether that LOAD's base dtype is Float32.
    fn check_load_dtype(uop: &Arc<UOp>) -> bool {
        if let Op::Load { .. } = uop.op() {
            return uop.dtype().base() == ScalarDType::Float32;
        }
        uop.op().children().into_iter().any(|child| check_load_dtype(child))
    }

    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let picked = create_vector_index_iota(buf.clone(), 4).gep(vec![1, 3]);
    let moved = apply_gep_movement(&UOp::load().buffer(buf).index(picked).call());

    assert!(check_load_dtype(&moved) || moved.dtype().base() == ScalarDType::Float32);
}
/// An identity GEP `[0, 1, 2, 3]` on a 4-wide index keeps all four lanes of
/// the LOAD after GEP movement.
#[test]
fn test_gep_after_load_identity_indices() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let all_lanes = create_vector_index_iota(buf.clone(), 4).gep(vec![0, 1, 2, 3]);
    let moved = apply_gep_movement(&UOp::load().buffer(buf).index(all_lanes).call());

    assert_eq!(moved.dtype().vcount(), 4);
}
/// A STORE through an identity GEP `[0, 1]` must still rewrite to a STORE
/// (or a GROUP of STOREs).
#[test]
fn test_gep_on_store_identity() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let index = create_vector_index_iota(buf, 2);
    let values = create_vector_float_values(vec![1.0, 2.0]);
    let rewritten = apply_gep_movement(&index.gep(vec![0, 1]).store(values));

    let op = rewritten.op();
    assert!(
        matches!(op, Op::Store { .. } | Op::Group { .. }),
        "Expected STORE or GROUP, got {:?}",
        op
    );
}
/// A STORE through a swapping GEP `[1, 0]` must rewrite to a STORE (or a
/// GROUP of them). The swap permutation is also checked to be self-inverse.
#[test]
fn test_gep_on_store_swap() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let index = create_vector_index_iota(buf, 2);
    let values = create_vector_float_values(vec![1.0, 2.0]);
    let rewritten = apply_gep_movement(&index.gep(vec![1, 0]).store(values));

    assert_eq!(compute_inverse_permutation(&[1, 0]), vec![1, 0], "Swap is self-inverse");
    assert!(matches!(rewritten.op(), Op::Store { .. } | Op::Group { .. }));
}
/// A STORE through a 3-cycle GEP `[2, 0, 1]` must rewrite to a STORE (or a
/// GROUP of them). The inverse of that cycle is checked to be `[1, 2, 0]`.
#[test]
fn test_gep_on_store_complex_permutation() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let index = create_vector_index_iota(buf, 3);
    let values = create_vector_float_values(vec![1.0, 2.0, 3.0]);
    let permuted = index.gep(vec![2, 0, 1]);
    let rewritten = apply_gep_movement(&permuted.store(values));

    assert_eq!(compute_inverse_permutation(&[2, 0, 1]), vec![1, 2, 0]);
    assert!(matches!(rewritten.op(), Op::Store { .. } | Op::Group { .. }));
}
/// A STORE carrying a loop range and indexed through a swapping GEP `[1, 0]`
/// must keep its ranges after GEP movement. This holds whether the rewrite
/// yields a single STORE or a GROUP of STOREs.
#[test]
fn test_gep_on_store_preserves_ranges() {
    /// True when `uop` is a STORE with at least one range, or a non-empty
    /// GROUP whose every source satisfies the same condition.
    fn has_ranges(uop: &Arc<UOp>) -> bool {
        match uop.op() {
            Op::Store { ranges, .. } => !ranges.is_empty(),
            // `Iterator::all` is vacuously true on an empty iterator. Without
            // the emptiness guard, a GROUP that lost every STORE would wrongly
            // count as "ranges preserved".
            Op::Group { sources } => !sources.is_empty() && sources.iter().all(has_ranges),
            _ => false,
        }
    }

    let buffer = create_buffer_typed(64, ScalarDType::Float32);
    let idx = create_vector_index_iota(buffer.clone(), 2);
    let value = create_vector_float_values(vec![1.0, 2.0]);
    let range = create_range_loop(8, 0);
    let gep_idx = idx.gep(vec![1, 0]);
    let store = gep_idx.store_with_ranges(value, smallvec![range]);
    let result = apply_gep_movement(&store);

    assert!(has_ranges(&result), "Ranges should be preserved");
}
/// A STORE through the 4-lane GEP `[3, 1, 2, 0]` (swap of the outer lanes)
/// must rewrite to a STORE (or a GROUP of them). That permutation is
/// checked to be its own inverse.
#[test]
fn test_gep_on_store_4_element() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let index = create_vector_index_iota(buf, 4);
    let values = create_vector_float_iota(4);
    let permuted = index.gep(vec![3, 1, 2, 0]);
    let rewritten = apply_gep_movement(&permuted.store(values));

    assert_eq!(compute_inverse_permutation(&[3, 1, 2, 0]), vec![3, 1, 2, 0]);
    assert!(matches!(rewritten.op(), Op::Store { .. } | Op::Group { .. }));
}
/// A STORE whose index is an interleaving GEP `[0, 2, 1, 3]` over a PTRCAT of
/// two 2-wide index vectors must rewrite to a STORE (or a GROUP of them).
#[test]
fn test_gep_on_store_with_ptrcat() {
    let lhs_buf = create_buffer_typed(32, ScalarDType::Float32);
    let rhs_buf = create_buffer_typed(32, ScalarDType::Float32);
    let lhs_idx = create_vector_index_iota(lhs_buf, 2);
    let rhs_idx = create_vector_index_iota(rhs_buf, 2);
    let cat = UOp::ptrcat().sources(vec![lhs_idx, rhs_idx]).call();
    let values = create_vector_float_iota(4);
    let interleaved = cat.gep(vec![0, 2, 1, 3]);
    let rewritten = apply_gep_movement(&interleaved.store(values));

    let is_store_like = matches!(rewritten.op(), Op::Store { .. } | Op::Group { .. });
    assert!(is_store_like, "Expected STORE or GROUP after PTRCAT distribution");
}
/// Runs a LOAD→STORE round-trip through the same two-lane GEP `[0, 2]` along
/// the folding → correction → render pipeline. The final node must still be
/// a STORE (or a GROUP of them).
#[test]
fn test_gep_movement_in_folding() {
    let buf = create_buffer_typed(64, ScalarDType::Float32);
    let lanes = create_vector_index_iota(buf.clone(), 4).gep(vec![0, 2]);
    let loaded = UOp::load().buffer(buf).index(lanes.clone()).call();
    let round_trip = lanes.store(loaded);

    let rendered = apply_pm_render(&apply_correct_load_store(&apply_load_store_folding(&round_trip)));

    assert!(matches!(rendered.op(), Op::Store { .. } | Op::Group { .. }));
}
/// Runs a LOAD through an identity GEP over a PTRCAT of two index vectors
/// along the folding → correction → render pipeline. Afterwards no LOAD may
/// be indexed directly by a PTRCAT.
#[test]
fn test_gep_movement_enables_ptrcat_distribution() {
    let lo_buf = create_buffer_typed(32, ScalarDType::Float32);
    let hi_buf = create_buffer_typed(32, ScalarDType::Float32);
    let lo_idx = create_vector_index_iota(lo_buf.clone(), 2);
    let hi_idx = create_vector_index_iota(hi_buf, 2);
    let cat = UOp::ptrcat().sources(vec![lo_idx, hi_idx]).call();
    let load = UOp::load().buffer(lo_buf).index(cat.gep(vec![0, 1, 2, 3])).call();

    let rendered = apply_pm_render(&apply_correct_load_store(&apply_load_store_folding(&load)));

    // Count LOADs whose index operand is itself a PTRCAT.
    let loads_indexed_by_ptrcat = count_ops(&rendered, |u| match u.op() {
        Op::Load { index, .. } => matches!(index.op(), Op::PtrCat { .. }),
        _ => false,
    });
    assert_eq!(loads_indexed_by_ptrcat, 0, "PTRCAT should be distributed");
}
/// The identity permutation is its own inverse.
#[test]
fn test_inverse_permutation_identity() {
    assert_eq!(compute_inverse_permutation(&[0, 1, 2, 3]), vec![0, 1, 2, 3]);
}
/// Reversing the lanes twice restores the original order, so the reversal
/// permutation is its own inverse.
#[test]
fn test_inverse_permutation_reverse() {
    assert_eq!(
        compute_inverse_permutation(&[3, 2, 1, 0]),
        vec![3, 2, 1, 0],
        "Reverse is self-inverse"
    );
}
/// The inverse of the rotation `[1, 2, 3, 0]` is the opposite rotation
/// `[3, 0, 1, 2]`.
#[test]
fn test_inverse_permutation_rotation() {
    assert_eq!(compute_inverse_permutation(&[1, 2, 3, 0]), vec![3, 0, 1, 2]);
}
/// A non-cyclic, non-involutive permutation: `[2, 0, 3, 1]` inverts to
/// `[1, 3, 0, 2]`.
#[test]
fn test_inverse_permutation_complex() {
    assert_eq!(compute_inverse_permutation(&[2, 0, 3, 1]), vec![1, 3, 0, 2]);
}