use std::sync::Arc;
use morok_dtype::{AddrSpace, DType, DeviceSpec};
use morok_ir::{Op, UOp};
use crate::rangeify::patterns::extract_device_from_graph;
/// Builds a CPU buffer and, if a device can be extracted from it, checks that
/// the extracted device is `DeviceSpec::Cpu`.
///
/// NOTE(review): when extraction yields `None` nothing is asserted and the test
/// still passes — confirm whether `None` is a legitimate outcome for a buffer.
#[test]
fn test_extract_device_from_buffer() {
    let cpu_buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let extracted = extract_device_from_graph(&cpu_buffer);
    if let Some(found) = extracted {
        assert_eq!(found, DeviceSpec::Cpu);
    }
}
/// A UOp created through `UOp::device` must carry an `Op::Device` holding the
/// spec it was created with.
#[test]
fn test_device_uop_creation() {
    let device_uop = UOp::device(DeviceSpec::Cpu);
    match device_uop.op() {
        Op::Device(held_spec) => assert_eq!(*held_spec, DeviceSpec::Cpu),
        _ => panic!("Expected Device op"),
    }
}
/// Sanity check that `AddrSpace::Global` compares equal to itself.
#[test]
fn test_addrspace_global() {
    let space = AddrSpace::Global;
    assert_eq!(space, AddrSpace::Global);
}
/// Sanity check that `AddrSpace::Local` compares equal to itself.
#[test]
fn test_addrspace_local() {
    let space = AddrSpace::Local;
    assert_eq!(space, AddrSpace::Local);
}
/// Two independently written `DeviceSpec::Cpu` values must compare equal.
#[test]
fn test_device_spec_equality() {
    let (first, second) = (DeviceSpec::Cpu, DeviceSpec::Cpu);
    assert_eq!(first, second, "Same device specs should be equal");
}
/// Two buffers created for the same device must carry equal device specs.
///
/// Unlike a conditional `if let` chain, every structural expectation here is
/// enforced: if either UOp is not an `Op::Buffer`, or either buffer's device is
/// not an `Op::Device`, the test fails instead of silently passing without
/// comparing anything.
#[test]
fn test_multiple_buffers_same_device() {
    let a = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let b = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);

    // Both freshly created UOps are expected to be buffer ops.
    let (Op::Buffer { device: dev_a, .. }, Op::Buffer { device: dev_b, .. }) = (a.op(), b.op())
    else {
        panic!("Expected Buffer ops");
    };

    // Each buffer's device child is expected to be a device op.
    let (Op::Device(spec_a), Op::Device(spec_b)) = (dev_a.op(), dev_b.op()) else {
        panic!("Expected Device ops");
    };

    assert_eq!(spec_a, spec_b, "Same device type should match");
}
/// Adds two CPU buffers and inspects the topologically sorted graph of the
/// result: at least two buffer nodes must be present, and every buffer node
/// must reference an `Op::Device` as its device.
#[test]
fn test_device_propagation_through_ops() {
    let lhs = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let rhs = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let sum = lhs.try_add(&rhs).unwrap();
    let order = sum.toposort();

    // First pass: count the buffer nodes reachable from the sum.
    let mut seen_buffers = 0usize;
    for node in order.iter() {
        if matches!(node.op(), Op::Buffer { .. }) {
            seen_buffers += 1;
        }
    }
    assert!(seen_buffers >= 2, "Should have buffer nodes from inputs");

    // Second pass: each buffer node must point at a device op.
    order.iter().for_each(|node| {
        if let Op::Buffer { device, .. } = node.op() {
            assert!(matches!(device.op(), Op::Device(_)), "Buffer should have device");
        }
    });
}
/// A view created from a buffer must be an `Op::BufferView` whose `buffer`
/// field is the very same `Arc` as the source buffer (pointer equality).
///
/// Only the shared buffer reference is checked here; the device itself is not
/// inspected directly.
#[test]
fn test_buffer_view_inherits_device() {
    let source = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let view = source.view(50, 10);
    let Op::BufferView { buffer: viewed, .. } = view.op() else {
        panic!("Expected BufferView op");
    };
    assert!(Arc::ptr_eq(viewed, &source));
}
/// Runs device extraction on a bare constant and discards the result.
///
/// NOTE(review): nothing is asserted about the returned value; the test only
/// shows that the call completes without panicking — confirm whether `None`
/// should be asserted here.
#[test]
fn test_constant_no_device() {
    let constant = UOp::native_const(42.0f32);
    let _ = extract_device_from_graph(&constant);
}
/// The `Debug` rendering of `DeviceSpec::Cpu` must mention "Cpu".
#[test]
fn test_device_spec_debug() {
    let spec = DeviceSpec::Cpu;
    let rendered = format!("{spec:?}");
    assert!(rendered.contains("Cpu"), "Debug should contain device name");
}