use std::sync::Arc;
use morok_dtype::{AddrSpace, DType};
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use smallvec::{SmallVec, smallvec};
use crate::devectorize::devectorize;
use super::helpers::*;
/// `apply_cast_after` on AFTER(CAST(x)) with no deps must yield CAST(AFTER(x)):
/// a Float64 cast whose source is an AFTER with a Float32 passthrough and an
/// empty dependency list.
#[test]
fn test_cast_after_basic() {
    let after = UOp::const_(DType::Float32, ConstValue::Float(1.0))
        .cast(DType::Float64)
        .after(smallvec![]);
    let rewritten = apply_cast_after(&after);

    // Outer node: the cast, still targeting Float64.
    let (inner, dtype) = match rewritten.op() {
        Op::Cast { src, dtype } => (src, dtype),
        other => panic!("Expected CAST, got {:?}", other),
    };
    assert_eq!(*dtype, DType::Float64);

    // Inner node: the AFTER, now wrapping the un-cast value.
    match inner.op() {
        Op::After { passthrough, deps } => {
            assert_eq!(passthrough.dtype(), DType::Float32);
            assert!(deps.is_empty());
        }
        other => panic!("Expected AFTER inside CAST, got {:?}", other),
    }
}
/// Same rewrite as the basic case, but the AFTER carries one dependency; the
/// resulting CAST(AFTER(..)) must still list exactly one dep.
#[test]
fn test_cast_after_with_deps() {
    let value = UOp::const_(DType::Float32, ConstValue::Float(1.0));
    let dep = UOp::const_(DType::Int32, ConstValue::Int(42));
    let after = value.cast(DType::Float64).after(smallvec![dep]);

    let rewritten = apply_cast_after(&after);

    let inner = match rewritten.op() {
        Op::Cast { src, .. } => src,
        _ => panic!("Expected CAST"),
    };
    let dep_count = match inner.op() {
        Op::After { deps, .. } => deps.len(),
        _ => panic!("Expected AFTER"),
    };
    assert_eq!(dep_count, 1);
}
/// An AFTER whose passthrough is a plain constant (no cast) must come back from
/// `apply_cast_after` as an AFTER with a CONST passthrough.
#[test]
fn test_after_without_cast_unchanged() {
    let after = UOp::const_(DType::Float32, ConstValue::Float(1.0)).after(smallvec![]);
    let rewritten = apply_cast_after(&after);

    let op = rewritten.op();
    assert!(matches!(op, Op::After { .. }));
    match op {
        Op::After { passthrough, .. } => assert!(matches!(passthrough.op(), Op::Const(_))),
        // Excluded by the assertion directly above.
        _ => unreachable!(),
    }
}
/// An INDEX gated by a constant `true` must come out of
/// `apply_load_store_indexing` as an INDEX without a gate.
#[test]
fn test_drop_true_gate() {
    let buffer = create_buffer(64);
    let idx = UOp::const_(DType::Index, ConstValue::Int(0));
    let always_true = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let gated_index =
        UOp::new(Op::Index { buffer, indices: smallvec![idx], gate: Some(always_true) }, DType::Float32);

    let rewritten = apply_load_store_indexing(&gated_index);

    let Op::Index { gate, .. } = rewritten.op() else {
        panic!("Expected INDEX, got {:?}", rewritten.op());
    };
    assert!(gate.is_none(), "Gate should be dropped");
}
/// An INDEX gated by a constant `false` must keep its gate after
/// `apply_load_store_indexing`.
#[test]
fn test_false_gate_unchanged() {
    let buffer = create_buffer(64);
    let idx = UOp::const_(DType::Index, ConstValue::Int(0));
    let always_false = UOp::const_(DType::Bool, ConstValue::Bool(false));
    let gated_index =
        UOp::new(Op::Index { buffer, indices: smallvec![idx], gate: Some(always_false) }, DType::Float32);

    let rewritten = apply_load_store_indexing(&gated_index);

    let Op::Index { gate, .. } = rewritten.op() else {
        panic!("Expected INDEX, got {:?}", rewritten.op());
    };
    assert!(gate.is_some(), "False gate should not be dropped");
}
/// An INDEX built without a gate must still have no gate after
/// `apply_load_store_indexing`.
#[test]
fn test_no_gate_unchanged() {
    let buffer = create_buffer(64);
    let idx = UOp::const_(DType::Index, ConstValue::Int(0));
    let ungated = UOp::index().buffer(buffer).indices(vec![idx]).call().unwrap();

    let rewritten = apply_load_store_indexing(&ungated);

    let Op::Index { gate, .. } = rewritten.op() else {
        panic!("Expected INDEX, got {:?}", rewritten.op());
    };
    assert!(gate.is_none());
}
/// Feeds a DEFINE_LOCAL with a vec4 pointer dtype through `apply_cast_after`.
///
/// Two outcomes are accepted:
/// * a CAST to a pointer whose base has `vcount() == 4`, wrapping a
///   DEFINE_LOCAL whose pointer base has `vcount() == 1`;
/// * the DEFINE_LOCAL returned as-is (no further checks).
#[test]
fn test_devectorize_define_local_vec4() {
    let vec_ptr_dtype = DType::Float32.vec(4).ptr(Some(16), AddrSpace::Local);
    let def_local = UOp::define_local(0, vec_ptr_dtype);

    let rewritten = apply_cast_after(&def_local);

    match rewritten.op() {
        // Node left untouched: nothing more to verify.
        Op::DefineLocal(_) => {}
        Op::Cast { src, dtype } => {
            let outer_is_vec4_ptr = match dtype {
                DType::Ptr { base, .. } => base.vcount() == 4,
                _ => false,
            };
            assert!(outer_is_vec4_ptr);
            assert!(matches!(src.op(), Op::DefineLocal(_)));

            match src.dtype() {
                DType::Ptr { base, .. } => {
                    assert_eq!(base.vcount(), 1, "Inner should have scalar base");
                }
                _ => panic!("Expected Ptr dtype"),
            }
        }
        other => panic!("Expected CAST or DEFINE_LOCAL, got {:?}", other),
    }
}
/// A DEFINE_LOCAL with a scalar pointer dtype must still be a DEFINE_LOCAL
/// after `apply_cast_after`.
#[test]
fn test_define_local_scalar_unchanged() {
    let def_local = UOp::define_local(0, DType::Float32.ptr(Some(16), AddrSpace::Local));
    let rewritten = apply_cast_after(&def_local);
    assert!(matches!(rewritten.op(), Op::DefineLocal(_)));
}
/// Runs `apply_devectorize` on a load indexed via `create_vector_index_iota(buffer, 4)`
/// and checks that at least one LOAD is present in the result.
#[test]
fn test_full_devectorize_simple_load() {
    let buffer = create_buffer(64);
    let iota_index = create_vector_index_iota(buffer.clone(), 4);
    let vector_load = UOp::load().buffer(buffer).index(iota_index).call();

    let devectorized = apply_devectorize(&vector_load);

    assert!(count_loads(&devectorized) >= 1, "Should have at least one LOAD in the result");
}
/// Runs `apply_devectorize` on a load indexed via
/// `create_vector_index_scaled(buffer, 4, 2)` and checks the result's dtype
/// reports a `vcount()` of at least 1.
// NOTE(review): presumably the helper builds a 4-wide index scaled by 2 — confirm
// against `helpers`.
#[test]
fn test_devectorize_non_contiguous() {
    let buffer = create_buffer(64);
    let scaled_index = create_vector_index_scaled(buffer.clone(), 4, 2);
    let load = UOp::load().buffer(buffer).index(scaled_index).call();

    let devectorized = apply_devectorize(&load);

    assert!(devectorized.dtype().vcount() >= 1);
}
/// Puts an AFTER(CAST(const)) and a plain load under one SINK, runs
/// `apply_devectorize`, and checks the root is still a SINK.
#[test]
fn test_cast_after_in_full_pipeline() {
    let buffer = create_buffer(64);
    let after = UOp::const_(DType::Float32, ConstValue::Float(1.0))
        .cast(DType::Float64)
        .after(smallvec![]);
    let scalar_index = create_index(buffer.clone(), 0);
    let load = UOp::load().buffer(buffer).index(scalar_index).call();
    let sink = UOp::sink(vec![after, load]);

    let lowered = apply_devectorize(&sink);

    assert!(matches!(lowered.op(), Op::Sink { .. }));
}
/// Runs `apply_devectorize` on a load whose INDEX is gated by a constant `true`
/// and checks that at least one LOAD remains in the result.
#[test]
fn test_gate_dropping_in_full_pipeline() {
    let buffer = create_buffer(64);
    let idx = UOp::const_(DType::Index, ConstValue::Int(0));
    let always_true = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let index_op = Op::Index { buffer: buffer.clone(), indices: smallvec![idx], gate: Some(always_true) };
    let gated_index = UOp::new(index_op, DType::Float32);
    let load = UOp::load().buffer(buffer).index(gated_index).call();

    let lowered = apply_devectorize(&load);

    assert!(count_loads(&lowered) >= 1);
}
/// Runs `apply_pm_render` on a load whose INDEX is gated by a constant `false`.
///
/// If the result is a LOAD, it must carry an `alt` value that is a constant
/// integer 0 or float 0.0. Any other resulting op is only logged, not treated
/// as a failure.
#[test]
fn test_gated_load_gets_alt() {
    let buffer = create_buffer(64);
    let idx = UOp::const_(DType::Index, ConstValue::Int(0));
    let always_false = UOp::const_(DType::Bool, ConstValue::Bool(false));
    let index_op = Op::Index { buffer: buffer.clone(), indices: smallvec![idx], gate: Some(always_false) };
    let gated_index = UOp::new(index_op, DType::Float32);
    let load = UOp::load().buffer(buffer).index(gated_index).call();

    let rendered = apply_pm_render(&load);

    let alt = match rendered.op() {
        Op::Load { alt, .. } => alt,
        other => {
            tracing::debug!("Gated load transformed to: {:?}", other);
            return;
        }
    };

    assert!(alt.is_some(), "Gated LOAD should have alt value after pm_render");
    if let Some(alt_val) = alt {
        let is_zero = match alt_val.op() {
            Op::Const(cv) => match cv.0 {
                ConstValue::Int(0) => true,
                ConstValue::Float(f) => f == 0.0,
                _ => false,
            },
            _ => false,
        };
        assert!(is_zero, "Alt value should be 0");
    }
}
/// Runs `apply_pm_render` on a load with an ungated INDEX; if the result is a
/// LOAD it must not have an `alt` value. Other resulting ops are not checked.
#[test]
fn test_ungate_load_unchanged() {
    let buffer = create_buffer(64);
    let idx = UOp::const_(DType::Index, ConstValue::Int(0));
    let ungated_index = UOp::index().buffer(buffer.clone()).indices(vec![idx]).call().unwrap();
    let load = UOp::load().buffer(buffer).index(ungated_index).call();

    let rendered = apply_pm_render(&load);

    match rendered.op() {
        Op::Load { alt, .. } => assert!(alt.is_none(), "Ungated LOAD should not have alt value"),
        _ => {}
    }
}
/// A loop RANGE built over `index_const(16)` must report `is_increasing()`.
#[test]
fn test_is_increasing_range() {
    let bound = UOp::index_const(16);
    let range = UOp::range_axis(
        bound,
        morok_ir::types::AxisId::Unrenumbered(0),
        morok_ir::types::AxisType::Loop,
    );
    assert!(range.is_increasing(), "RANGE should be increasing");
}
/// An Int32 constant must report `is_increasing()`.
#[test]
fn test_is_increasing_constant() {
    let five = UOp::const_(DType::Int32, ConstValue::Int(5));
    let increasing = five.is_increasing();
    assert!(increasing, "CONST should be increasing");
}
/// The sum of a loop RANGE and an Index constant must report `is_increasing()`.
#[test]
fn test_is_increasing_add_expr() {
    let range = UOp::range_axis(
        UOp::index_const(16),
        morok_ir::types::AxisId::Unrenumbered(0),
        morok_ir::types::AxisType::Loop,
    );
    let offset = UOp::const_(DType::Index, ConstValue::Int(5));

    let shifted = range.try_add(&offset).unwrap();

    assert!(shifted.is_increasing(), "RANGE + CONST should be increasing");
}
/// A loop RANGE multiplied by the Index constant 4 must report `is_increasing()`.
#[test]
fn test_is_increasing_mul_positive() {
    let range = UOp::range_axis(
        UOp::index_const(16),
        morok_ir::types::AxisId::Unrenumbered(0),
        morok_ir::types::AxisType::Loop,
    );
    let factor = UOp::const_(DType::Index, ConstValue::Int(4));

    let scaled = range.try_mul(&factor).unwrap();

    assert!(scaled.is_increasing(), "RANGE * positive CONST should be increasing");
}
/// A variable `x` (declared with bounds 0 and 100) multiplied by the constant -1
/// must NOT report `is_increasing()`.
#[test]
fn test_is_increasing_mul_negative() {
    let x = UOp::var("x", DType::Int32, 0, 100);
    let minus_one = UOp::const_(DType::Int32, ConstValue::Int(-1));

    let negated = x.try_mul(&minus_one).unwrap();
    let increasing = negated.is_increasing();

    assert!(!increasing, "x * negative CONST should not be increasing");
}
/// Stores a 9-element vector through INDEX(CAST(DEFINE_LOCAL), vec3 index) and
/// runs `devectorize` on the enclosing SINK.
///
/// Afterwards:
/// * no INDEX may remain that both uses a multi-lane first index and addresses
///   a CAST wrapped directly around a DEFINE_LOCAL;
/// * the graph must contain at least 3 stores.
#[test]
fn test_devectorize_local_buffer_vector_index() {
    // Scalar local buffer reinterpreted as a pointer to vec3 elements.
    let vec3_ptr_dtype = DType::Float32.vec(3).ptr(Some(9), AddrSpace::Local);
    let scalar_ptr_dtype = DType::Float32.ptr(Some(9), AddrSpace::Local);
    let scalar_def = UOp::define_local(1, scalar_ptr_dtype);
    let cast_def = scalar_def.cast(vec3_ptr_dtype);

    // Vector index <0, 1, 2> into the cast buffer.
    let vec3_idx = UOp::vectorize(smallvec![UOp::index_const(0), UOp::index_const(1), UOp::index_const(2)]);
    let index = UOp::new(
        Op::Index { buffer: cast_def, indices: smallvec![vec3_idx], gate: None },
        DType::Float32.vec(9).ptr(Some(9), AddrSpace::Local),
    );

    // Value to store: the constants 0.0 through 8.0 packed into one vector.
    let value_elements: SmallVec<[Arc<UOp>; 4]> =
        (0..9).map(|i| UOp::const_(DType::Float32, ConstValue::Float(i as f64))).collect();
    let vec9_value = UOp::vectorize(value_elements);
    let store = index.store(vec9_value);
    let sink = UOp::sink(vec![store]);

    let result = devectorize(&sink);

    let has_vector_local_index = result.toposort().iter().any(|node: &Arc<UOp>| {
        if let Op::Index { buffer, indices, .. } = node.op() {
            let has_vec_idx = indices.first().is_some_and(|i| i.dtype().vcount() > 1);
            let is_local_cast = matches!(buffer.op(), Op::Cast { src, .. }
                if matches!(src.op(), Op::DefineLocal(_)));
            has_vec_idx && is_local_cast
        } else {
            false
        }
    });
    assert!(!has_vector_local_index, "After devectorize, no INDEX(CAST(DEF_LOCAL), vec_idx) should remain");

    let store_count = count_stores(&result);
    assert!(store_count >= 3, "Expected at least 3 stores (grouped), got {store_count}");
}
/// Stores an 81-element vector through INDEX(CAST(DEFINE_LOCAL), vec9 index) and
/// runs `devectorize` on the enclosing SINK.
///
/// Afterwards:
/// * no INDEX may remain that both uses a multi-lane first index and addresses
///   a CAST wrapped directly around a DEFINE_LOCAL;
/// * the graph must contain at least 9 stores.
#[test]
fn test_devectorize_local_buffer_vec9_index() {
    // Scalar local buffer reinterpreted as a pointer to vec9 elements.
    let vec9_ptr_dtype = DType::Float32.vec(9).ptr(Some(81), AddrSpace::Local);
    let scalar_ptr_dtype = DType::Float32.ptr(Some(81), AddrSpace::Local);
    let cast_def = UOp::define_local(2, scalar_ptr_dtype).cast(vec9_ptr_dtype);

    // Vector index <0, ..., 8> into the cast buffer.
    let index_lanes: SmallVec<[Arc<UOp>; 4]> = (0..9i64).map(UOp::index_const).collect();
    let vec9_idx = UOp::vectorize(index_lanes);
    let index = UOp::new(
        Op::Index { buffer: cast_def, indices: smallvec![vec9_idx], gate: None },
        DType::Float32.vec(81).ptr(Some(81), AddrSpace::Local),
    );

    // Value to store: the constants 0.0 through 80.0 packed into one vector.
    let value_lanes: SmallVec<[Arc<UOp>; 4]> =
        (0..81).map(|i| UOp::const_(DType::Float32, ConstValue::Float(i as f64))).collect();
    let store = index.store(UOp::vectorize(value_lanes));
    let sink = UOp::sink(vec![store]);

    let result = devectorize(&sink);

    let nodes = result.toposort();
    let has_vector_local_index = nodes.iter().any(|node: &Arc<UOp>| match node.op() {
        Op::Index { buffer, indices, .. } => {
            let vector_indexed = indices.first().is_some_and(|i| i.dtype().vcount() > 1);
            let casts_local = match buffer.op() {
                Op::Cast { src, .. } => matches!(src.op(), Op::DefineLocal(_)),
                _ => false,
            };
            vector_indexed && casts_local
        }
        _ => false,
    });
    assert!(!has_vector_local_index, "After devectorize, no INDEX(CAST(DEF_LOCAL), vec_idx) should remain");

    let store_count = count_stores(&result);
    assert!(store_count >= 9, "Expected at least 9 stores (grouped), got {store_count}");
}