#![allow(clippy::mutable_key_type)]
use std::collections::HashMap;
use std::sync::Arc;
use morok_dtype::DType;
use morok_dtype::DeviceSpec;
use crate::{AxisId, ConstValue, Op, UOp};
/// A native `f32` constant must report `DType::Float32` and wrap an `Op::Const`.
#[test]
fn test_const_creation() {
    let constant = UOp::native_const(1.0f32);
    let wraps_const = matches!(constant.op(), Op::Const(_));
    assert_eq!(constant.dtype(), DType::Float32);
    assert!(wraps_const);
}
/// Building the same constant twice must yield the very same allocation,
/// i.e. both handles point at one shared node.
#[test]
fn test_hash_consing() {
    let c1 = UOp::native_const(1.0f32);
    let c2 = UOp::native_const(1.0f32);
    // The handles are `Arc`s (compared with `Arc::ptr_eq`), so the message names `Arc`.
    assert!(Arc::ptr_eq(&c1, &c2), "Hash consing should return same Arc for identical UOps");
}
/// Adding the same two operands twice must return one shared node, not two copies.
#[test]
fn test_hash_consing_with_src() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let first_sum = lhs.try_add(&rhs).unwrap();
    let second_sum = lhs.try_add(&rhs).unwrap();
    let same_node = Arc::ptr_eq(&first_sum, &second_sum);
    assert!(same_node, "Hash consing should work with src nodes");
}
/// Ten threads, released together by a barrier, each build the constant `42.0f32`.
/// Every thread must end up with the same `Arc` as thread 0.
#[test]
fn test_cross_thread_hash_consing() {
    use std::sync::Barrier;
    let num_threads = 10;
    let barrier = Arc::new(Barrier::new(num_threads));
    // All workers are spawned before any is joined so that each can reach the barrier.
    let handles: Vec<_> = (0..num_threads)
        .map(|_| {
            let b = Arc::clone(&barrier);
            std::thread::spawn(move || {
                b.wait();
                UOp::native_const(42.0f32)
            })
        })
        .collect();
    let uops: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
    // Compare every later result against the first one; `skip(1)` keeps the
    // reported index equal to the thread's position in `uops`.
    for (i, uop) in uops.iter().enumerate().skip(1) {
        assert!(
            Arc::ptr_eq(&uops[0], uop),
            "Thread {} got different Arc than thread 0 (id {} vs {})",
            i,
            uop.id,
            uops[0].id
        );
    }
}
/// Eight threads, released together by a barrier, each build `(1.0 + 2.0) * 3.0`.
/// The root node returned by every thread must be the same `Arc` as thread 0's.
#[test]
fn test_cross_thread_hash_consing_complex() {
    use std::sync::Barrier;
    let num_threads = 8;
    let barrier = Arc::new(Barrier::new(num_threads));
    // All workers are spawned before any is joined so that each can reach the barrier.
    let handles: Vec<_> = (0..num_threads)
        .map(|_| {
            let b = Arc::clone(&barrier);
            std::thread::spawn(move || {
                b.wait();
                let a = UOp::native_const(1.0f32);
                let b_val = UOp::native_const(2.0f32);
                let c = UOp::native_const(3.0f32);
                let add = a.try_add(&b_val).unwrap();
                add.try_mul(&c).unwrap()
            })
        })
        .collect();
    let uops: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();
    // `skip(1)` keeps the reported index equal to the thread's position in `uops`.
    for (i, uop) in uops.iter().enumerate().skip(1) {
        assert!(Arc::ptr_eq(&uops[0], uop), "Thread {} got different Arc for complex expression", i);
    }
}
/// Add and multiply of two `f32` constants keep the `Float32` dtype, and the
/// add node exposes exactly two children.
#[test]
fn test_binary_operations() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);

    let sum = lhs.try_add(&rhs).unwrap();
    assert_eq!(sum.dtype(), DType::Float32);
    let operand_count = sum.op().children().len();
    assert_eq!(operand_count, 2);

    let product = lhs.try_mul(&rhs).unwrap();
    assert_eq!(product.dtype(), DType::Float32);
}
/// Square root of an `f32` constant stays `Float32` and has a single child.
#[test]
fn test_unary_operations() {
    let operand = UOp::native_const(4.0f32);
    let root = operand.try_sqrt().unwrap();
    assert_eq!(root.dtype(), DType::Float32);
    let operand_count = root.op().children().len();
    assert_eq!(operand_count, 1);
}
/// Casting an `f32` constant to `Int32` produces a node whose dtype is `Int32`.
#[test]
fn test_cast() {
    let source = UOp::native_const(1.5f32);
    let as_int = source.cast(DType::Int32);
    assert_eq!(as_int.dtype(), DType::Int32);
}
/// A less-than comparison of two `f32` constants yields a `Bool`-typed node.
#[test]
fn test_comparison() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let less_than = lhs.try_cmplt(&rhs).unwrap();
    assert_eq!(less_than.dtype(), DType::Bool);
}
/// `toposort` on `(1.0 + 2.0) * 3.0` must list at least five nodes and place
/// every child before the node that uses it.
#[test]
fn test_toposort() {
    let one = UOp::native_const(1.0f32);
    let two = UOp::native_const(2.0f32);
    let three = UOp::native_const(3.0f32);
    let sum = one.try_add(&two).unwrap();
    let product = sum.try_mul(&three).unwrap();

    let order = product.toposort();
    assert!(order.len() >= 5);

    // Map each node (by address) to its index in the sorted output.
    let mut index_of = HashMap::new();
    for (i, node) in order.iter().enumerate() {
        index_of.insert(Arc::as_ptr(node), i);
    }

    for node in &order {
        let node_pos = index_of[&Arc::as_ptr(node)];
        for child in node.op().children() {
            let child_pos = index_of[&Arc::as_ptr(child)];
            assert!(child_pos < node_pos, "Dependencies must come before dependents");
        }
    }
}
/// A node reachable through two paths (`a` feeds both `a + b` and `a + c`)
/// must be emitted only once by `toposort`.
#[test]
fn test_toposort_shared_node() {
    let a = UOp::native_const(1.0f32);
    let b = UOp::native_const(2.0f32);
    let c = UOp::native_const(3.0f32);
    let left = a.try_add(&b).unwrap();
    let right = a.try_add(&c).unwrap();
    let root = left.try_mul(&right).unwrap();

    let order = root.toposort();
    let shared_ptr = Arc::as_ptr(&a);
    // Count entries that are the very same allocation as `a`.
    let mut occurrences = 0usize;
    for node in &order {
        if Arc::as_ptr(node) == shared_ptr {
            occurrences += 1;
        }
    }
    assert_eq!(occurrences, 1, "Shared node 'a' should appear exactly once");
}
/// `new_buffer` must produce an `Op::Buffer` node with the requested dtype and size.
#[test]
fn test_buffer_creation() {
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    assert!(matches!(buffer.op(), Op::Buffer { .. }));
    assert_eq!(buffer.dtype(), DType::Float32);
    match buffer.op() {
        Op::Buffer { size, .. } => assert_eq!(*size, 100),
        _ => panic!("Expected Buffer op"),
    }
}
/// Buffers are expected NOT to be deduplicated: two `new_buffer` calls with
/// identical arguments must return different `Arc`s.
#[test]
fn test_buffer_hash_consing() {
    let first = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let second = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let shared = Arc::ptr_eq(&first, &second);
    assert!(!shared, "Different buffers should have different UNIQUE ids");
}
/// `view(100, 50)` on a buffer must yield an `Op::BufferView` that keeps the
/// buffer's dtype and records size 100 and offset 50.
#[test]
fn test_buffer_view() {
    let backing = UOp::new_buffer(DeviceSpec::Cpu, 1000, DType::Float32);
    let window = backing.view(100, 50);
    assert!(matches!(window.op(), Op::BufferView { .. }));
    assert_eq!(window.dtype(), DType::Float32);
    match window.op() {
        Op::BufferView { size, offset, .. } => {
            assert_eq!(*size, 100);
            assert_eq!(*offset, 50);
        }
        _ => panic!("Expected BufferView op"),
    }
}
/// Indexing a buffer with one index constant must build an `Op::Index` node
/// that reports two children.
#[test]
fn test_index_operation() {
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Float32);
    let position = UOp::const_(DType::Index, ConstValue::UInt(10));
    let indexed = UOp::index()
        .buffer(buffer)
        .indices(vec![position])
        .call()
        .expect("index should succeed");
    assert!(matches!(indexed.op(), Op::Index { .. }));
    // Presumably the buffer plus the single index — confirm against `Op::Index`.
    let child_count = indexed.op().children().len();
    assert_eq!(child_count, 2);
}
/// `UOp::device` must wrap the given spec in `Op::Device`, and `UOp::buffer_id`
/// must produce `Op::Unique` both for an explicit id and for `None`.
#[test]
fn test_device_and_unique() {
    let device = UOp::device(DeviceSpec::Cpu);
    assert!(matches!(device.op(), Op::Device(_)));
    match device.op() {
        Op::Device(spec) => assert_eq!(*spec, DeviceSpec::Cpu),
        _ => {}
    }

    // An explicit id is carried through unchanged.
    let explicit = UOp::buffer_id(Some(42));
    assert!(matches!(explicit.op(), Op::Unique(42)));

    // With `None` only the variant is checked; the id value is not asserted.
    let automatic = UOp::buffer_id(None);
    assert!(matches!(automatic.op(), Op::Unique(_)));
}
/// `children()` on an add node must return the two operands, in operand order,
/// as the very same `Arc`s that were passed in.
#[test]
fn test_children_method() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();

    let operands = sum.op().children();
    assert_eq!(operands.len(), 2);
    let first_is_lhs = Arc::ptr_eq(operands[0], &lhs);
    assert!(first_is_lhs);
    let second_is_rhs = Arc::ptr_eq(operands[1], &rhs);
    assert!(second_is_rhs);
}
/// `map_child` must invoke its callback once per operand of an add node,
/// visiting the left operand before the right one.
#[test]
fn test_for_each_child() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();

    // Record every child handed to the callback, in visiting order.
    let mut visited = Vec::new();
    sum.op().map_child(|child| visited.push(child.clone()));

    assert_eq!(visited.len(), 2);
    let first_is_lhs = Arc::ptr_eq(&visited[0], &lhs);
    assert!(first_is_lhs);
    let second_is_rhs = Arc::ptr_eq(&visited[1], &rhs);
    assert!(second_is_rhs);
}
/// A scalar constant must report a shape, and that shape must have zero dimensions.
#[test]
fn test_shape_property_scalar() {
    let scalar = UOp::native_const(42.0f32);
    let maybe_shape = scalar.shape().unwrap();
    assert!(maybe_shape.is_some(), "Scalar should have shape");
    let dims = maybe_shape.unwrap();
    assert_eq!(dims.len(), 0, "Scalar should have empty shape");
}
/// The shape cache of a freshly built node must be empty until the first
/// `ShapeProperty::get`, populated afterwards, and a second `get` must hand
/// back the same cached reference.
///
/// The operands are derived from the clock so that the `add` node is unlikely
/// to coincide with an expression some other test has already queried.
#[test]
fn test_shape_property_lazy_evaluation() {
    use crate::uop::cached_property::CachedProperty;
    use crate::uop::properties::ShapeProperty;
    // `f32` has a 24-bit significand, so only integers up to 2^24 are exact.
    // Raw epoch nanoseconds (~1e18) would collapse to a value that changes only
    // every couple of minutes and for which `+ 1` is a no-op. Folding the clock
    // into [2^23, 2^24) keeps both `seed` and `seed + 1` exactly representable
    // and distinct, and stays clear of the small literals used by other tests.
    const WINDOW: u128 = 1 << 23;
    let nanos = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos();
    let seed = WINDOW + nanos % WINDOW;
    let a = UOp::native_const(seed as f32);
    let b = UOp::native_const((seed + 1) as f32);
    let add = a.try_add(&b).unwrap();
    assert!(ShapeProperty::cache(&add).get().is_none(), "Cache should be empty before first access");
    let shape1 = ShapeProperty::get(&add);
    assert!(shape1.is_ok() && shape1.as_ref().unwrap().is_some());
    assert!(ShapeProperty::cache(&add).get().is_some(), "Cache should be populated after first access");
    let shape2 = ShapeProperty::get(&add);
    assert!(std::ptr::eq(shape1, shape2), "Second access should return same cached reference");
}
/// An expression built only from constants must report no ranges.
#[test]
fn test_ranges_property_no_ranges() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();
    let found = sum.ranges();
    assert_eq!(found.len(), 0, "No RANGE ops in simple arithmetic");
}
/// A cast applied to a range node must report exactly that range via `ranges()`.
#[test]
fn test_ranges_property_with_range() {
    use crate::AxisType;
    let upper = UOp::index_const(10);
    let range = UOp::range_axis(upper, AxisId::Renumbered(0), AxisType::Loop);
    let as_float = range.cast(DType::Float32);

    let found = as_float.ranges();
    assert_eq!(found.len(), 1, "Should find one RANGE op");
    let is_same_range = Arc::ptr_eq(&found[0], &range);
    assert!(is_same_range);
}
/// The ranges cache of a node must be empty until the first
/// `RangesProperty::get`, populated afterwards, and a second `get` must hand
/// back the same cached reference.
#[test]
fn test_ranges_property_lazy_evaluation() {
    use crate::AxisType;
    use crate::uop::cached_property::CachedProperty;
    use crate::uop::properties::RangesProperty;
    // Nodes are shared process-wide, so a graph identical to one another test
    // builds (and queries) could already have its cache filled when tests run in
    // parallel. A loop bound used nowhere else in this file keeps this node private.
    let end = UOp::index_const(7919);
    let range = UOp::range_axis(end, AxisId::Renumbered(0), AxisType::Loop);
    let idx = range.cast(DType::Float32);
    assert!(RangesProperty::cache(&idx).get().is_none(), "Cache should be empty before first access");
    let ranges1 = RangesProperty::get(&idx);
    assert_eq!(ranges1.len(), 1);
    assert!(RangesProperty::cache(&idx).get().is_some(), "Cache should be populated after first access");
    let ranges2 = RangesProperty::get(&idx);
    assert!(std::ptr::eq(ranges1, ranges2), "Second access should return same cached reference");
    assert!(Arc::ptr_eq(&ranges1[0], &ranges2[0]));
}
/// A range node reports itself as in scope, and a cast built on top of it
/// reports one in-scope range as well.
#[test]
fn test_in_scope_ranges_simple() {
    use crate::AxisType;
    let upper = UOp::index_const(10);
    let range = UOp::range_axis(upper, AxisId::Renumbered(0), AxisType::Loop);

    let scope_of_range = range.in_scope_ranges();
    assert_eq!(scope_of_range.len(), 1, "RANGE should have itself in scope");

    let as_float = range.cast(DType::Float32);
    let scope_of_cast = as_float.in_scope_ranges();
    assert_eq!(scope_of_cast.len(), 1, "Computation should inherit RANGE scope");
}
/// The in-scope-ranges cache of a node must be empty until the first
/// `InScopeRangesProperty::get`, populated afterwards, and a second `get` must
/// hand back the same cached reference.
#[test]
fn test_in_scope_ranges_lazy_evaluation() {
    use crate::AxisType;
    use crate::uop::cached_property::CachedProperty;
    use crate::uop::properties::InScopeRangesProperty;
    // Nodes are shared process-wide, so a graph identical to one another test
    // builds (and queries) could already have its cache filled when tests run in
    // parallel. A loop bound used nowhere else in this file keeps this node private.
    let end = UOp::index_const(7927);
    let range = UOp::range_axis(end, AxisId::Renumbered(0), AxisType::Loop);
    let idx = range.cast(DType::Float32);
    assert!(InScopeRangesProperty::cache(&idx).get().is_none(), "Cache should be empty before first access");
    let in_scope1 = InScopeRangesProperty::get(&idx);
    assert_eq!(in_scope1.len(), 1);
    assert!(InScopeRangesProperty::cache(&idx).get().is_some(), "Cache should be populated after first access");
    let in_scope2 = InScopeRangesProperty::get(&idx);
    assert!(std::ptr::eq(in_scope1, in_scope2), "Second access should return same cached reference");
}
/// Closing a range with `end` must leave no range in scope on the resulting node.
#[test]
fn test_in_scope_ranges_after_end() {
    use crate::AxisType;
    use smallvec::smallvec;
    let upper = UOp::index_const(10);
    let range = UOp::range_axis(upper, AxisId::Renumbered(0), AxisType::Loop);
    let body = UOp::native_const(1.0f32);

    let closed = body.end(smallvec![range.clone()]);
    let remaining = closed.in_scope_ranges();
    assert_eq!(remaining.len(), 0, "After END, range should not be in scope");
}
/// With two ranges created, a bare constant has nothing in scope, and ending
/// the second range over that constant still leaves nothing in scope.
#[test]
fn test_in_scope_ranges_nested() {
    use crate::AxisType;
    use smallvec::smallvec;
    // The first range is only constructed; nothing below refers to it.
    let outer_end = UOp::index_const(10);
    let _range1 = UOp::range_axis(outer_end, AxisId::Renumbered(0), AxisType::Loop);
    let inner_end = UOp::index_const(20);
    let inner_range = UOp::range_axis(inner_end, AxisId::Renumbered(1), AxisType::Loop);

    let body = UOp::native_const(1.0f32);
    let scope_before = body.in_scope_ranges();
    assert_eq!(scope_before.len(), 0, "Const has no ranges in scope initially");

    let closed = body.end(smallvec![inner_range.clone()]);
    let scope_after = closed.in_scope_ranges();
    assert_eq!(scope_after.len(), 0, "After END, ranges are not propagated to parent");
}
/// With a gate that accepts only the root, `toposort_filtered` must return
/// just the root node.
#[test]
fn test_toposort_filtered_basic() {
    let one = UOp::native_const(1.0f32);
    let sum = one.try_add(&UOp::native_const(2.0f32)).unwrap();
    let root = sum.try_mul(&UOp::native_const(3.0f32)).unwrap();

    let kept = root.toposort_filtered(|candidate| Arc::ptr_eq(candidate, &root));
    assert_eq!(kept.len(), 1, "Filtered toposort should only include nodes passing gate");
    let only_root = Arc::ptr_eq(&kept[0], &root);
    assert!(only_root);
}
/// A gate that accepts every node must make `toposort_filtered` return as many
/// nodes as the unfiltered `toposort`.
#[test]
fn test_toposort_filtered_all() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();

    let gated = sum.toposort_filtered(|_| true);
    let ungated = sum.toposort();
    let gated_len = gated.len();
    let ungated_len = ungated.len();
    assert_eq!(gated_len, ungated_len);
}
/// A gate that rejects every node must make `toposort_filtered` return nothing.
#[test]
fn test_toposort_filtered_none() {
    let node = UOp::native_const(1.0f32);
    let kept = node.toposort_filtered(|_| false);
    assert_eq!(kept.len(), 0, "Gate blocking all nodes should return empty");
}
/// Shape, ranges and in-scope ranges can all be queried on one node, and
/// querying them a second time gives results consistent with the first.
#[test]
fn test_multiple_properties_coexist() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let sum = lhs.try_add(&rhs).unwrap();

    // First round: a scalar expression has an empty shape and no ranges.
    let first_shape = sum.shape().unwrap();
    assert!(first_shape.is_some());
    assert_eq!(first_shape.unwrap().len(), 0);
    let first_ranges = sum.ranges();
    assert_eq!(first_ranges.len(), 0);
    let first_scope = sum.in_scope_ranges();
    assert_eq!(first_scope.len(), 0);

    // Second round: repeat all three queries and compare with the first round.
    let second_shape = sum.shape().unwrap();
    let second_ranges = sum.ranges();
    let second_scope = sum.in_scope_ranges();
    assert_eq!(first_shape, second_shape);
    assert_eq!(first_ranges.len(), second_ranges.len());
    assert_eq!(first_scope.len(), second_scope.len());
}