use std::sync::Arc;
use morok_device::DeviceSpec;
use morok_dtype::DType;
use morok_ir::{UOp, UOpKey};
use crate::rangeify::transforms::{OpAccessType, as_buf, find_bufs};
/// A single STORE through an INDEX into one buffer must yield exactly one entry:
/// that buffer, classified as a `Store` access.
#[test]
// The map returned by `find_bufs` is keyed by `UOpKey`; clippy's `mutable_key_type`
// lint flags that key type, so it is silenced for this test.
#[allow(clippy::mutable_key_type)]
fn test_find_bufs_store_only() {
    let target = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);
    let one = UOp::native_const(1.0f32);

    let target_at = UOp::index()
        .buffer(target.clone())
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = target_at.store(one);

    let accesses = find_bufs(&store);

    assert_eq!(accesses.len(), 1);
    assert_eq!(accesses.get(&UOpKey(target)), Some(&OpAccessType::Store));
}
/// Copying `src[0]` into `dst[0]` touches two distinct buffers. The source must be
/// reported as a `Load` and the destination as a `Store`.
#[test]
// `find_bufs` returns a map keyed by `UOpKey`, which clippy's `mutable_key_type` lint flags.
#[allow(clippy::mutable_key_type)]
fn test_find_bufs_load_only() {
    let src = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);

    // Read side: LOAD(src, INDEX(src, 0)).
    let src_at = UOp::index()
        .buffer(src.clone())
        .indices(vec![zero.clone()])
        .call()
        .unwrap();
    let value = UOp::load().buffer(src.clone()).index(src_at).call();

    // Write side: STORE(INDEX(dst, 0), value).
    let dst = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let dst_at = UOp::index()
        .buffer(dst.clone())
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = dst_at.store(value);

    let accesses = find_bufs(&store);

    let expected = [(src, OpAccessType::Load), (dst, OpAccessType::Store)];
    assert_eq!(accesses.len(), expected.len());
    for (buf, kind) in expected {
        assert_eq!(accesses.get(&UOpKey(buf)), Some(&kind));
    }
}
/// Reading and writing the same buffer in one store graph is a conflict. The test
/// expects `find_bufs` to panic with a message containing
/// "buffer accessed with conflicting ops".
#[test]
#[should_panic(expected = "buffer accessed with conflicting ops")]
fn test_find_bufs_conflicting_access() {
    let shared = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);

    // Build `shared[0] = shared[0]`, so the same buffer sits behind both a LOAD and a STORE.
    let read_at = UOp::index()
        .buffer(shared.clone())
        .indices(vec![zero.clone()])
        .call()
        .unwrap();
    let current = UOp::load().buffer(shared.clone()).index(read_at).call();
    let write_at = UOp::index()
        .buffer(shared)
        .indices(vec![zero])
        .call()
        .unwrap();

    find_bufs(&write_at.store(current));
}
/// `dst[0] = lhs[0] + rhs[0]`. Each input buffer must be reported as a `Load` and the
/// output as a `Store`, with no other entries.
#[test]
// `find_bufs` returns a map keyed by `UOpKey`, which clippy's `mutable_key_type` lint flags.
#[allow(clippy::mutable_key_type)]
fn test_find_bufs_multiple_buffers() {
    let lhs = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let rhs = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let dst = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let zero = UOp::index_const(0);

    let lhs_val = {
        let at = UOp::index()
            .buffer(lhs.clone())
            .indices(vec![zero.clone()])
            .call()
            .unwrap();
        UOp::load().buffer(lhs.clone()).index(at).call()
    };
    let rhs_val = {
        let at = UOp::index()
            .buffer(rhs.clone())
            .indices(vec![zero.clone()])
            .call()
            .unwrap();
        UOp::load().buffer(rhs.clone()).index(at).call()
    };
    let total = lhs_val.try_add(&rhs_val).unwrap();

    let dst_at = UOp::index()
        .buffer(dst.clone())
        .indices(vec![zero])
        .call()
        .unwrap();
    let store = dst_at.store(total);

    let accesses = find_bufs(&store);

    assert_eq!(accesses.len(), 3);
    let expected = [
        (lhs, OpAccessType::Load),
        (rhs, OpAccessType::Load),
        (dst, OpAccessType::Store),
    ];
    for (buf, kind) in expected {
        assert_eq!(accesses.get(&UOpKey(buf)), Some(&kind));
    }
}
/// A gate on the INDEX nodes (here a constant `true`) must not hide the underlying
/// buffers. Both sides of the gated copy are still classified.
#[test]
// `find_bufs` returns a map keyed by `UOpKey`, which clippy's `mutable_key_type` lint flags.
#[allow(clippy::mutable_key_type)]
fn test_find_bufs_with_gated_index() {
    let src = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let dst = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let always = UOp::native_const(true);

    let src_at = UOp::index()
        .buffer(src.clone())
        .indices(vec![UOp::index_const(0)])
        .gate(always.clone())
        .call()
        .unwrap();
    let value = UOp::load().buffer(src.clone()).index(src_at).call();

    let dst_at = UOp::index()
        .buffer(dst.clone())
        .indices(vec![UOp::index_const(0)])
        .gate(always)
        .call()
        .unwrap();
    let store = dst_at.store(value);

    let accesses = find_bufs(&store);

    assert_eq!(accesses.len(), 2);
    assert_eq!(accesses.get(&UOpKey(src)), Some(&OpAccessType::Load));
    assert_eq!(accesses.get(&UOpKey(dst)), Some(&OpAccessType::Store));
}
/// `as_buf` looks through an MSELECT and returns the very same buffer node:
/// pointer-identical, not merely equal.
#[test]
fn test_as_buf_mselect() {
    let base = UOp::buffer_id(Some(0));
    let selected = base.mselect(0);

    let resolved = as_buf(&selected);

    assert!(Arc::ptr_eq(&resolved, &base));
}
/// For an MSTACK of two buffers, `as_buf` resolves to the first stacked buffer,
/// compared by pointer identity.
#[test]
fn test_as_buf_mstack() {
    let first = UOp::buffer_id(Some(1));
    let second = UOp::buffer_id(Some(2));
    let stacked = UOp::mstack(vec![first.clone(), second].into());

    let resolved = as_buf(&stacked);

    assert!(Arc::ptr_eq(&resolved, &first));
}
/// An AFTER node wrapping a buffer (here ordered after a single NOOP) is transparent
/// to `as_buf`, which must return the original buffer node itself.
#[test]
fn test_as_buf_after() {
    let base = UOp::buffer_id(Some(0));
    let deps = smallvec::smallvec![UOp::noop()];
    let sequenced = base.after(deps);

    let resolved = as_buf(&sequenced);

    assert!(Arc::ptr_eq(&resolved, &base));
}