use morok_dtype::DType;
use morok_dtype::DeviceSpec;
use crate::types::{AddrSpace, AxisId, AxisType, BufferizeOpts};
use crate::{Op, UOp};
/// Builds a bufferize op over two loop ranges with CPU options and checks
/// that the result has `Float32` dtype, shares the original compute node,
/// carries both ranges, and records the CPU device with the global address
/// space.
#[test]
fn test_bufferize() {
    use std::sync::Arc;

    let source = UOp::native_const(1.0f32);
    let outer = UOp::range_axis(UOp::native_const(10i32), AxisId::Renumbered(0), AxisType::Loop);
    let inner = UOp::range_axis(UOp::native_const(20i32), AxisId::Renumbered(1), AxisType::Loop);
    let cpu_opts = BufferizeOpts::new(DeviceSpec::Cpu);

    let node = UOp::bufferize(source.clone(), vec![outer, inner], cpu_opts);
    assert_eq!(node.dtype(), DType::Float32);

    match node.op() {
        Op::Bufferize { compute, ranges, opts } => {
            // The compute child must be the very same allocation, not a copy.
            assert!(Arc::ptr_eq(compute, &source));
            assert_eq!(ranges.len(), 2);
            assert_eq!(opts.device, Some(DeviceSpec::Cpu));
            assert_eq!(opts.addrspace, AddrSpace::Global);
        }
        _ => panic!("Expected Bufferize op"),
    }
}
/// Checks that a bufferize op built from `BufferizeOpts::local()` reports
/// the local address space.
#[test]
fn test_bufferize_local() {
    let source = UOp::native_const(1.0f32);
    let range = UOp::range_axis(UOp::native_const(10i32), AxisId::Renumbered(0), AxisType::Loop);
    let local_opts = BufferizeOpts::local();

    let node = UOp::bufferize(source, vec![range], local_opts);

    match node.op() {
        Op::Bufferize { opts, .. } => assert_eq!(opts.addrspace, AddrSpace::Local),
        _ => panic!("Expected Bufferize op"),
    }
}
/// Builds a load from a `Float32` CPU buffer at a constant index and checks
/// that the load has `Float32` dtype and references the exact buffer and
/// index nodes it was built from.
#[test]
fn test_load() {
    use std::sync::Arc;

    let src_buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let position = UOp::index_const(0);

    let node = UOp::load().buffer(src_buffer.clone()).index(position.clone()).call();
    assert_eq!(node.dtype(), DType::Float32);

    match node.op() {
        Op::Load { buffer, index, .. } => {
            // Pointer identity: the op must hold the same nodes, not clones of their contents.
            assert!(Arc::ptr_eq(buffer, &src_buffer));
            assert!(Arc::ptr_eq(index, &position));
        }
        _ => panic!("Expected Load op"),
    }
}
/// Stores a constant through an index node into a CPU buffer and checks
/// that the store has `Void` dtype, references the exact index and value
/// nodes, and that `store_buffer()` resolves to the original buffer.
#[test]
fn test_store() {
    use std::sync::Arc;

    let dst_buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let offset = UOp::index_const(0);
    let payload = UOp::native_const(42.0f32);
    let target = UOp::index().buffer(dst_buffer.clone()).indices(vec![offset]).call().unwrap();

    let node = target.store_value(payload.clone());
    assert_eq!(node.dtype(), DType::Void);

    match node.op() {
        Op::Store { index, value, .. } => {
            assert!(Arc::ptr_eq(index, &target));
            assert!(Arc::ptr_eq(value, &payload));
            // The buffer is reached through the store's helper rather than a direct field.
            assert!(Arc::ptr_eq(node.store_buffer().unwrap(), &dst_buffer));
        }
        _ => panic!("Expected Store op"),
    }
}
/// Creates a param op with slot 0, size 1024, a pointer dtype and no device,
/// then checks that the dtype is a pointer and that slot, size and device
/// are stored as given.
#[test]
fn test_codegen_param() {
    let ptr_dtype = DType::Float32.ptr(Some(1024), morok_dtype::AddrSpace::Global);
    let param = UOp::param(0, 1024, ptr_dtype, None);

    assert!(matches!(param.dtype(), DType::Ptr { .. }));

    match param.op() {
        Op::Param { slot, size, device } => {
            assert_eq!(*slot, 0);
            assert_eq!(*size, 1024);
            assert!(device.is_none());
        }
        _ => panic!("Expected Param op"),
    }
}
/// Checks that `UOp::define_local(1, DType::Int32)` produces a `DefineLocal`
/// op with id 1 and `Int32` dtype.
#[test]
fn test_define_local() {
    let local = UOp::define_local(1, DType::Int32);
    assert_eq!(local.dtype(), DType::Int32);

    match local.op() {
        Op::DefineLocal(id) => assert_eq!(*id, 1),
        _ => panic!("Expected DefineLocal op"),
    }
}