use morok_dtype::ScalarDType;
use morok_ir::Op;
use super::helpers::*;
/// Runs `apply_bool_storage` over a load built from `create_bool_buffer(64)`.
///
/// Two outcomes are accepted:
/// * a CAST whose dtype base is bool and whose source passes `assert_is_load`
///   and has a uint8 base dtype, or
/// * a LOAD whose own dtype base is still bool.
///
/// Any other op panics the test.
#[test]
fn test_bool_load_to_uint8() {
    let bool_buf = create_bool_buffer(64);
    let idx = create_index(bool_buf.clone(), 0);
    let original = create_load(bool_buf.clone(), idx);
    // Precondition: the helper-built load is bool-typed before the rewrite.
    assert_eq!(original.dtype().base(), ScalarDType::Bool);

    let rewritten = apply_bool_storage(&original);
    match rewritten.op() {
        Op::Load { .. } => {
            assert_eq!(rewritten.dtype().base(), ScalarDType::Bool, "LOAD result should be bool");
        }
        Op::Cast { src: inner, dtype: target } => {
            assert_eq!(target.base(), ScalarDType::Bool, "CAST should produce bool");
            assert_is_load(inner);
            assert_eq!(inner.dtype().base(), ScalarDType::UInt8, "Inner LOAD should be uint8");
        }
        unexpected => panic!("Expected CAST(LOAD) or LOAD, got {:?}", unexpected),
    }
}
/// A load from the `create_buffer(64)` buffer (asserted here to have a
/// Float32 base dtype) must still pass `assert_is_load` and keep its Float32
/// base dtype after `apply_bool_storage`.
#[test]
fn test_non_bool_load_unchanged() {
    let float_buf = create_buffer(64);
    let idx = create_index(float_buf.clone(), 0);
    let original = create_load(float_buf.clone(), idx);
    // Precondition: the input load is not bool-typed.
    assert_eq!(original.dtype().base(), ScalarDType::Float32);

    let rewritten = apply_bool_storage(&original);

    assert_is_load(&rewritten);
    assert_eq!(rewritten.dtype().base(), ScalarDType::Float32);
}
/// A load from a buffer created with `create_buffer_typed(64, ScalarDType::Int32)`
/// must still pass `assert_is_load` and report an Int32 base dtype after
/// `apply_bool_storage`.
#[test]
fn test_int32_load_unchanged() {
    let int_buf = create_buffer_typed(64, ScalarDType::Int32);
    let idx = create_index(int_buf.clone(), 0);
    let original = create_load(int_buf.clone(), idx);

    let rewritten = apply_bool_storage(&original);

    assert_is_load(&rewritten);
    assert_eq!(rewritten.dtype().base(), ScalarDType::Int32);
}
/// Stores a bool constant through an index into a `create_bool_buffer(64)`
/// buffer and runs `apply_bool_storage` on the store.
///
/// The result must be a STORE. Its value must be either a CAST to a uint8
/// base dtype whose source has a bool base dtype, or a constant (accepted
/// without further checks). Anything else panics.
#[test]
fn test_bool_store_to_uint8() {
    let bool_buf = create_bool_buffer(64);
    let idx = create_index(bool_buf.clone(), 0);
    let flag = create_bool_const(true);
    let original = create_store(idx, flag);

    let rewritten = apply_bool_storage(&original);

    // First peel off the STORE, then inspect the value it writes.
    let stored = match rewritten.op() {
        Op::Store { value, .. } => value,
        other => panic!("Expected STORE, got {:?}", other),
    };
    match stored.op() {
        Op::Const(_) => {}
        Op::Cast { src: inner, dtype: target } => {
            assert_eq!(target.base(), ScalarDType::UInt8);
            assert_eq!(inner.dtype().base(), ScalarDType::Bool);
        }
        other => panic!("Expected CAST or Const value, got {:?}", other),
    }
}
/// Stores a float constant through an index into the `create_buffer(64)`
/// buffer and checks that, after `apply_bool_storage`, the result is a STORE
/// whose value still has a Float32 base dtype.
#[test]
fn test_non_bool_store_unchanged() {
    let float_buf = create_buffer(64);
    let idx = create_index(float_buf.clone(), 0);
    let three = create_float_const(3.0);
    let original = create_store(idx, three.clone());

    let rewritten = apply_bool_storage(&original);

    let stored = match rewritten.op() {
        Op::Store { value, .. } => value,
        other => panic!("Expected STORE, got {:?}", other),
    };
    assert_eq!(stored.dtype().base(), ScalarDType::Float32);
}
/// Applies `apply_bool_storage` to both a bool store and a bool load that
/// share the same index into a `create_bool_buffer(64)` buffer.
///
/// Store side: the result must be a STORE whose value is a CAST or a constant.
/// Load side: the result must be either a CAST with a bool base dtype or a
/// LOAD whose base dtype is bool.
///
/// Both checks are exhaustive so that an unexpected op fails the test instead
/// of being skipped by a non-matching `if let`. The accepted shapes mirror
/// the ones enforced by `test_bool_store_to_uint8` and
/// `test_bool_load_to_uint8` for the same inputs.
#[test]
fn test_bool_roundtrip() {
    let buffer = create_bool_buffer(64);
    let index = create_index(buffer.clone(), 0);

    let bool_val = create_bool_const(true);
    let store = create_store(index.clone(), bool_val);
    let store_result = apply_bool_storage(&store);

    let load = create_load(buffer.clone(), index);
    let load_result = apply_bool_storage(&load);

    // Store side: must remain a STORE writing a CAST or a constant.
    match store_result.op() {
        Op::Store { value, .. } => assert!(
            matches!(value.op(), Op::Cast { .. } | Op::Const(_)),
            "Expected CAST or Const value, got {:?}",
            value.op()
        ),
        other => panic!("Expected STORE, got {:?}", other),
    }

    // Load side: either wrapped in a CAST back to bool, or left as a bool LOAD.
    match load_result.op() {
        Op::Cast { dtype, .. } => {
            assert_eq!(dtype.base(), ScalarDType::Bool, "CAST should produce bool");
        }
        Op::Load { .. } => {
            assert_eq!(load_result.dtype().base(), ScalarDType::Bool, "LOAD result should be bool");
        }
        other => panic!("Expected CAST(LOAD) or LOAD, got {:?}", other),
    }
}
/// Runs `apply_devectorize` on a load built from `create_bool_buffer(64)` and
/// checks that the resulting base dtype is either bool or uint8.
#[test]
fn test_bool_with_devectorize() {
    let bool_buf = create_bool_buffer(64);
    let idx = create_index(bool_buf.clone(), 0);
    let original = create_load(bool_buf.clone(), idx);

    let rewritten = apply_devectorize(&original);

    let base = rewritten.dtype().base();
    let is_bool = base == ScalarDType::Bool;
    let is_uint8 = base == ScalarDType::UInt8;
    assert!(is_bool || is_uint8, "Result should be bool or uint8");
}
/// Runs `apply_bool_storage` on a load built from `create_bool_buffer(64)`.
///
/// Accepts a CAST with a bool base dtype over a source with a uint8 base
/// dtype, or a bare LOAD (no further checks). Any other op panics.
///
/// NOTE(review): despite the name, no vector-specific helper is used here;
/// the load is built from a plain index at offset 0. Confirm whether vector
/// coverage was intended.
#[test]
fn test_vector_bool_load() {
    let bool_buf = create_bool_buffer(64);
    let idx = create_index(bool_buf.clone(), 0);
    let original = create_load(bool_buf.clone(), idx);

    let rewritten = apply_bool_storage(&original);

    match rewritten.op() {
        Op::Load { .. } => {}
        Op::Cast { src: inner, dtype: target } => {
            assert_eq!(target.base(), ScalarDType::Bool);
            assert_eq!(inner.dtype().base(), ScalarDType::UInt8);
        }
        unexpected => panic!("Expected CAST(LOAD) or LOAD, got {:?}", unexpected),
    }
}
/// Stores the value from `create_vector_bool(vec![true, false, true, false])`
/// through an index into a `create_bool_buffer(64)` buffer and runs
/// `apply_bool_storage` on the store.
///
/// The only assertion is conditional: if the result is a STORE whose value is
/// a CAST, that CAST must have a uint8 base dtype. Every other shape is
/// accepted.
///
/// NOTE(review): this test cannot fail for a non-STORE result or a non-CAST
/// value; confirm whether that leniency is intentional.
#[test]
fn test_vector_bool_store() {
    let bool_buf = create_bool_buffer(64);
    let idx = create_index(bool_buf.clone(), 0);
    let lanes = create_vector_bool(vec![true, false, true, false]);
    let original = create_store(idx, lanes);

    let rewritten = apply_bool_storage(&original);

    // Non-STORE results are tolerated, so there is nothing more to check.
    let stored = match rewritten.op() {
        Op::Store { value, .. } => value,
        _ => return,
    };
    // Only a CAST value carries a constraint; all other values are accepted.
    if let Op::Cast { dtype: target, .. } = stored.op() {
        assert_eq!(target.base(), ScalarDType::UInt8);
    }
}