use std::sync::Arc;
use morok_ir::{UOp, UOpKey};
/// Building the same `f32` constant twice must yield the same `Arc` allocation.
#[test]
fn test_identical_const_dedup() {
    let first = UOp::native_const(42.0f32);
    let second = UOp::native_const(42.0f32);

    // Pointer identity (not structural equality) is what is being checked.
    let same_node = Arc::ptr_eq(&first, &second);
    assert!(same_node, "Identical constants should be deduplicated");
}
/// Adding the same two operands twice must return the same `Arc` both times.
#[test]
fn test_identical_binary_op_dedup() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);

    // Build the identical expression twice from the same operands.
    let build_sum = || lhs.try_add(&rhs).unwrap();
    let first = build_sum();
    let second = build_sum();

    let same_node = Arc::ptr_eq(&first, &second);
    assert!(same_node, "Identical binary ops should be deduplicated");
}
/// An addition and a multiplication over the same operands must be distinct nodes.
#[test]
fn test_different_binary_op_not_dedup() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);

    let sum = lhs.try_add(&rhs).unwrap();
    let product = lhs.try_mul(&rhs).unwrap();

    let shared = Arc::ptr_eq(&sum, &product);
    assert!(!shared, "Different ops should not be deduplicated");
}
/// `UOpKey` values compare equal both for a cloned handle and for an
/// independently constructed constant with the same value.
#[test]
fn test_uopkey_equality() {
    let a = UOp::native_const(42.0f32);
    let b = Arc::clone(&a);
    let c = UOp::native_const(42.0f32);

    // `a` is not used after this point, so it is moved into the key rather
    // than cloned a second time.
    let key_a = UOpKey(a);
    let key_b = UOpKey(b);
    let key_c = UOpKey(c);

    assert_eq!(key_a, key_b, "Clone should have same key");
    assert_eq!(key_a, key_c, "Same value should have same key (hash consing)");
}
/// A value stored under a `UOpKey` must be retrievable both through a clone of
/// the original handle and through a freshly built constant of the same value.
#[test]
fn test_uopkey_hash_consistency() {
    use std::collections::HashMap;

    let original = UOp::native_const(42.0f32);

    // NOTE(review): clippy's `mutable_key_type` lint is silenced for this map;
    // presumably `UOpKey` wraps state clippy treats as interior-mutable — confirm.
    #[allow(clippy::mutable_key_type)]
    let mut lookup: HashMap<UOpKey, i32> = HashMap::new();
    lookup.insert(UOpKey(Arc::clone(&original)), 100);

    // Lookup via another handle to the same node.
    let via_clone = lookup.get(&UOpKey(Arc::clone(&original)));
    assert_eq!(via_clone, Some(&100));

    // Lookup via a constant constructed separately with the same value.
    let rebuilt = UOp::native_const(42.0f32);
    let via_rebuilt = lookup.get(&UOpKey(rebuilt));
    assert_eq!(via_rebuilt, Some(&100));
}
/// An operand feeding two different additions under one sink must be listed
/// exactly once by `toposort`.
#[test]
fn test_diamond_pattern_dedup() {
    let shared = UOp::native_const(1.0f32);
    let left_rhs = UOp::native_const(2.0f32);
    let right_rhs = UOp::native_const(3.0f32);

    let left = shared.try_add(&left_rhs).unwrap();
    let right = shared.try_add(&right_rhs).unwrap();
    let sink = UOp::sink(vec![left, right]);
    let order = sink.toposort();

    // Count by pointer identity how often the shared operand shows up.
    let mut occurrences = 0usize;
    for node in &order {
        if Arc::ptr_eq(node, &shared) {
            occurrences += 1;
        }
    }
    assert_eq!(occurrences, 1, "Shared input should appear once in toposort");
}
/// An intermediate node used as both operands of an addition must be listed
/// exactly once by `toposort`.
#[test]
fn test_reused_intermediate() {
    let lhs = UOp::native_const(1.0f32);
    let rhs = UOp::native_const(2.0f32);
    let partial = lhs.try_add(&rhs).unwrap();

    // `partial + partial`: the same node is referenced twice by its consumer.
    let doubled = partial.try_add(&partial).unwrap();
    let order = doubled.toposort();

    let mut occurrences = 0usize;
    for node in &order {
        if Arc::ptr_eq(node, &partial) {
            occurrences += 1;
        }
    }
    assert_eq!(occurrences, 1, "Reused intermediate should appear once");
}
/// The same arithmetic built from `f32` and from `f64` constants must produce
/// distinct nodes with different dtypes.
#[test]
fn test_cache_different_dtypes() {
    let one_f32 = UOp::native_const(1.0f32);
    let two_f32 = UOp::native_const(2.0f32);
    let one_f64 = UOp::native_const(1.0f64);
    let two_f64 = UOp::native_const(2.0f64);

    let sum_f32 = one_f32.try_add(&two_f32).unwrap();
    let sum_f64 = one_f64.try_add(&two_f64).unwrap();

    let shared = Arc::ptr_eq(&sum_f32, &sum_f64);
    assert!(!shared, "Different dtypes should not deduplicate");
    assert_ne!(sum_f32.dtype(), sum_f64.dtype());
}
/// Subtraction with swapped operands must not resolve to the same node.
#[test]
fn test_cache_order_matters() {
    let three = UOp::native_const(3.0f32);
    let one = UOp::native_const(1.0f32);

    let forward = three.try_sub(&one).unwrap();
    let reversed = one.try_sub(&three).unwrap();

    let shared = Arc::ptr_eq(&forward, &reversed);
    assert!(!shared, "Order should matter for non-commutative ops");
}
/// `toposort` must not list any node more than once, even when an
/// intermediate (`ab`) is consumed by two different parents.
#[test]
fn test_toposort_no_duplicates() {
    use std::collections::HashSet;

    let a = UOp::native_const(1.0f32);
    let b = UOp::native_const(2.0f32);
    let c = UOp::native_const(3.0f32);
    let ab = a.try_add(&b).unwrap();
    let abc = ab.try_add(&c).unwrap();
    // `ab` is reachable both directly and through `abc`.
    let result = abc.try_mul(&ab).unwrap();

    let topo = result.toposort();
    let total_count = topo.len();

    // Collect node addresses: two entries are the same node exactly when their
    // `Arc`s point at the same allocation, so the set size is the number of
    // distinct nodes. This replaces a quadratic scan over a `Vec`.
    let distinct: HashSet<*const UOp> = topo.iter().map(|node| Arc::as_ptr(node)).collect();

    assert_eq!(distinct.len(), total_count, "Toposort should have no duplicates");
}