use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{Op, TernaryOp, UOp};
use super::helpers::*;
/// Concatenating two vec4 floats yields a vec8 CAT.
///
/// After rendering, the result must be either a VECTORIZE of 8 scalar lanes
/// or a CAT that still has 8 sources.
#[test]
fn test_cat_vec4_to_vectorize() {
    let lhs = create_vector_float_iota(4);
    let rhs = create_vector_float_values(vec![10.0, 11.0, 12.0, 13.0]);
    let concatenated = UOp::cat().sources(vec![lhs, rhs]).call();
    assert_vcount(&concatenated, 8);

    let rendered = apply_pm_render(&concatenated);
    if let Op::Vectorize { elements } = rendered.op() {
        assert_eq!(elements.len(), 8, "Should have 8 elements");
        elements
            .iter()
            .for_each(|lane| assert_eq!(lane.dtype().vcount(), 1, "Each element should be scalar"));
    } else if let Op::Cat { sources } = rendered.op() {
        assert_eq!(sources.len(), 8);
    } else {
        panic!("Expected VECTORIZE or CAT, got {:?}", rendered.op());
    }
}
/// A CAT built from four scalar constants.
///
/// The rendered result may stay a CAT whose four sources are all scalars,
/// or it may become a 4-element VECTORIZE. Any other op is a failure.
#[test]
fn test_cat_scalar_unchanged() {
    let scalars = vec![
        create_float_const(1.0),
        create_float_const(2.0),
        create_float_const(3.0),
        create_float_const(4.0),
    ];
    let concatenated = UOp::cat().sources(scalars).call();
    let rendered = apply_pm_render(&concatenated);

    if let Op::Cat { sources } = rendered.op() {
        assert_eq!(sources.len(), 4);
        sources
            .iter()
            .for_each(|operand| assert_eq!(operand.dtype().vcount(), 1));
    } else if let Op::Vectorize { elements } = rendered.op() {
        assert_eq!(elements.len(), 4);
    } else {
        panic!("Expected CAT or VECTORIZE, got {:?}", rendered.op());
    }
}
/// A CAT with exactly one source renders to that very source node (same `Arc`).
#[test]
fn test_cat_single_source_unwrap() {
    let only = create_vector_float_iota(4);
    let wrapped = UOp::cat().sources(vec![Arc::clone(&only)]).call();
    let rendered = apply_pm_render(&wrapped);
    assert!(Arc::ptr_eq(&rendered, &only), "Single-source CAT should unwrap");
}
/// A single-index GEP into a 3-lane VECTORIZE of float constants.
///
/// The rendered result must be scalar. It is either the extracted constant
/// (lane 1 holds `1.0`) or a GEP that still carries exactly the index `1`.
#[test]
fn test_gep_vectorize_single() {
    let e0 = create_float_const(0.0);
    let e1 = create_float_const(1.0);
    let e2 = create_float_const(2.0);
    // `e1` is not referenced again, so it is moved in rather than cloned.
    let vec = UOp::vectorize([e0, e1, e2].into_iter().collect());
    let gep = vec.gep(vec![1]);
    let result = apply_pm_render(&gep);
    assert_eq!(result.dtype().vcount(), 1, "Should be scalar");
    match result.op() {
        Op::Const(v) => {
            assert_eq!(v.0, ConstValue::Float(1.0), "Should extract value 1.0");
        }
        Op::Gep { indices, .. } => {
            assert_eq!(indices.len(), 1, "Should have single index");
            assert_eq!(indices[0], 1, "Index should be 1");
        }
        other => panic!("Expected Const or GEP, got {:?}", other),
    }
}
/// Selecting lanes 0 and 2 of a 4-lane VECTORIZE renders to a 2-element VECTORIZE.
#[test]
fn test_gep_vectorize_multi() {
    let mut lanes: smallvec::SmallVec<[Arc<UOp>; 4]> = smallvec::SmallVec::new();
    for lane in 0..4 {
        lanes.push(create_float_const(lane as f64));
    }
    let source = UOp::vectorize(lanes);
    let picked = source.gep(vec![0, 2]);

    let rendered = apply_pm_render(&picked);
    assert_vcount(&rendered, 2);

    let Op::Vectorize { elements } = rendered.op() else {
        panic!("Expected VECTORIZE, got {:?}", rendered.op());
    };
    assert_eq!(elements.len(), 2);
}
/// A GEP into a broadcast of the constant 42.0 must yield a scalar.
///
/// The result is either the constant itself or a single-index GEP.
#[test]
fn test_gep_broadcast_extraction() {
    let scalar = create_float_const(42.0);
    let splat = scalar.broadcast(4);
    let lane = splat.gep(vec![2]);

    let rendered = apply_pm_render(&lane);
    assert_eq!(rendered.dtype().vcount(), 1, "Should be scalar");

    if let Op::Const(v) = rendered.op() {
        assert_eq!(v.0, ConstValue::Float(42.0), "Should extract value 42.0");
    } else if let Op::Gep { indices, .. } = rendered.op() {
        assert_eq!(indices.len(), 1, "Should have single index");
    } else {
        panic!("Expected Const or GEP, got {:?}", rendered.op());
    }
}
/// A 2-index GEP over a CAT of three scalars must render to a 2-lane value.
///
/// The result may be either a CAT or a VECTORIZE with two entries.
#[test]
fn test_gep_cat_reorder() {
    let a = create_float_const(1.0);
    let b = create_float_const(2.0);
    let c = create_float_const(3.0);
    // None of the operands is used again, so they are moved rather than cloned.
    let cat = UOp::cat().sources(vec![a, b, c]).call();
    let gep = cat.gep(vec![1, 2]);
    let result = apply_pm_render(&gep);
    assert_vcount(&result, 2);
    match result.op() {
        Op::Cat { sources } => {
            assert_eq!(sources.len(), 2);
        }
        Op::Vectorize { elements } => {
            assert_eq!(elements.len(), 2);
        }
        other => panic!("Expected CAT or VECTORIZE, got {:?}", other),
    }
}
/// A single-index GEP over a CAT of three scalars must render to a scalar.
#[test]
fn test_gep_cat_single() {
    let a = create_float_const(1.0);
    let b = create_float_const(2.0);
    let c = create_float_const(3.0);
    // `b` is not referenced again, so it is moved rather than cloned.
    let cat = UOp::cat().sources(vec![a, b, c]).call();
    let gep = cat.gep(vec![1]);
    let result = apply_pm_render(&gep);
    assert_eq!(result.dtype().vcount(), 1);
}
/// A PTRCAT with exactly one source (an index into a 64-element buffer)
/// renders to an INDEX node.
#[test]
fn test_ptrcat_single_unwrap() {
    let buffer = create_buffer(64);
    // Neither `buffer` nor `p` is referenced again, so both are moved
    // rather than cloned.
    let p = create_index(buffer, 0);
    let ptrcat = UOp::ptrcat().sources(vec![p]).call();
    let result = apply_pm_render(&ptrcat);
    assert_is_index(&result);
}
/// Re-assembling every lane of a vec4 through per-lane GEPs inside a CAT
/// must still render to a 4-lane value.
#[test]
fn test_cat_gep_identity() {
    let base = create_vector_float_iota(4);
    let mut lanes: Vec<Arc<UOp>> = Vec::with_capacity(4);
    for lane in 0..4 {
        lanes.push(base.gep(vec![lane]));
    }
    let reassembled = UOp::cat().sources(lanes).call();

    let rendered = apply_pm_render(&reassembled);
    assert_vcount(&rendered, 4);
}
/// A vec4 WHERE (vec4 bool mask, two vec4 float operands).
///
/// After rendering, the result has 4 lanes. It is either a VECTORIZE of 4
/// scalar WHEREs, or a WHERE whose condition and both branches are still vec4.
#[test]
fn test_where_devectorize() {
    let mask = create_vector_bool(vec![true, false, true, false]);
    let on_true = create_vector_float_iota(4);
    let on_false = create_vector_float_values(vec![10.0, 11.0, 12.0, 13.0]);
    let select = UOp::new(
        Op::Ternary(TernaryOp::Where, mask, on_true, on_false),
        DType::Float32.vec(4),
    );

    let rendered = apply_pm_render(&select);
    assert_eq!(rendered.dtype().vcount(), 4, "Result vcount should be 4");

    if let Op::Vectorize { elements } = rendered.op() {
        assert_eq!(elements.len(), 4, "Should have 4 scalar WHEREs");
        for lane in elements.iter() {
            assert!(
                matches!(lane.op(), Op::Ternary(TernaryOp::Where, _, _, _)),
                "Each element should be WHERE"
            );
            assert_eq!(lane.dtype().vcount(), 1, "Each WHERE should be scalar");
        }
    } else if let Op::Ternary(TernaryOp::Where, c, t, f) = rendered.op() {
        let operands = [
            (c, "Condition should be vec4"),
            (t, "True value should be vec4"),
            (f, "False value should be vec4"),
        ];
        for (operand, msg) in operands {
            assert_eq!(operand.dtype().vcount(), 4, "{}", msg);
        }
    } else {
        panic!("Expected VECTORIZE or WHERE, got {:?}", rendered.op());
    }
}
/// A scalar WHERE must still be a scalar WHERE after rendering.
#[test]
fn test_where_scalar_unchanged() {
    let select = UOp::new(
        Op::Ternary(
            TernaryOp::Where,
            create_bool_const(true),
            create_float_const(1.0),
            create_float_const(0.0),
        ),
        DType::Float32,
    );

    let rendered = apply_pm_render(&select);
    let still_where = matches!(rendered.op(), Op::Ternary(TernaryOp::Where, _, _, _));
    assert!(still_where, "Scalar WHERE should remain unchanged");
    assert_eq!(rendered.dtype().vcount(), 1);
}
/// A single-index GEP applied after a vec4 float to vec4 int64 cast must
/// render to a scalar.
#[test]
fn test_gep_through_cast() {
    let floats = create_vector_float_iota(4);
    let ints = floats.cast(DType::Int64.vec(4));
    let lane = ints.gep(vec![1]);

    let rendered = apply_pm_render(&lane);
    assert_eq!(rendered.dtype().vcount(), 1);
}
/// A 4-index GEP into a vec8 is normalized into a 4-element VECTORIZE.
///
/// Any element that is itself a GEP must carry a single index. Elements that
/// are not GEPs are not checked.
#[test]
fn test_multi_index_gep_normalizes() {
    let source = create_vector_float_iota(8);
    let window = source.gep(vec![0, 1, 2, 3]);

    let normalized = apply_vectorize_normalize(&window);
    let Op::Vectorize { elements } = normalized.op() else {
        panic!("Expected VECTORIZE, got {:?}", normalized.op());
    };
    assert_eq!(elements.len(), 4);

    elements
        .iter()
        .filter_map(|lane| match lane.op() {
            Op::Gep { indices, .. } => Some(indices),
            _ => None,
        })
        .for_each(|indices| assert_eq!(indices.len(), 1, "Each GEP should be single-index"));
}
/// GEP index 0 on a scalar must normalize to either the scalar itself
/// (same `Arc`) or some other scalar-typed node.
#[test]
fn test_gep_scalar_identity() {
    let scalar = create_float_const(42.0);
    let result = {
        let gep = scalar.gep(vec![0]);
        apply_vectorize_normalize(&gep)
    };
    assert!(Arc::ptr_eq(&result, &scalar) || result.dtype().vcount() == 1);
}
/// A VECTORIZE wrapping a single scalar normalizes to that scalar (same `Arc`).
#[test]
fn test_single_element_vectorize_unwrap() {
    let lone = create_float_const(42.0);
    let wrapped = UOp::vectorize(std::iter::once(Arc::clone(&lone)).collect());
    let normalized = apply_vectorize_normalize(&wrapped);
    assert!(Arc::ptr_eq(&normalized, &lone), "Single-element VECTORIZE should unwrap");
}
/// Building a PTRCAT with no sources must panic.
#[test]
#[should_panic]
fn test_empty_ptrcat_panics() {
    let _ = UOp::ptrcat().sources(Vec::new()).call();
}
/// Building a CAT with no sources must panic.
#[test]
#[should_panic]
fn test_empty_cat_panics() {
    let _ = UOp::cat().sources(Vec::new()).call();
}
/// Renders a GEP whose index (10) lies past the end of a 4-lane vector.
///
/// There is no assertion. The test passes as long as building the GEP and
/// rendering it do not panic.
/// NOTE(review): confirm whether out-of-range indices are meant to be
/// tolerated here or rejected (e.g. via `#[should_panic]`).
#[test]
fn test_gep_out_of_bounds() {
    let lanes = create_vector_float_iota(4);
    let past_end = lanes.gep(vec![10]);
    let _rendered = apply_pm_render(&past_end);
}