use std::sync::Arc;
use morok_ir::UOp;
use crate::rangeify::transforms::{flatten_range_impl, flatten_ranges};
/// A plain constant is not an op `flatten_range_impl` rewrites, so it must decline (`None`).
#[test]
fn test_flatten_range_impl_non_supported_op() {
    let constant = UOp::native_const(1.0f32);
    assert!(flatten_range_impl(&constant).is_none());
}
/// A STORE that has no range sources has nothing to flatten, so the rewrite returns `None`.
#[test]
fn test_flatten_range_impl_no_ranges() {
    let stored = UOp::index_const(0).store(UOp::native_const(1.0f32));
    assert!(flatten_range_impl(&stored).is_none());
}
/// When the graph has nothing to flatten, `flatten_ranges` must hand back the very same node
/// (pointer-identical), not a rebuilt copy.
#[test]
fn test_flatten_ranges_identity() {
    let leaf = UOp::native_const(1.0f32);
    let result = flatten_ranges(&leaf);
    assert!(Arc::ptr_eq(&result, &leaf));
}
/// `END(END(c, [r1]), [r2])` should collapse into a single END that carries both ranges.
#[test]
fn test_flatten_range_nested_end() {
    use morok_ir::Op;
    use smallvec::smallvec;

    let body = UOp::native_const(1.0f32);
    let inner_range = UOp::range(UOp::index_const(10), 0);
    let outer_range = UOp::range(UOp::index_const(20), 1);
    let nested = body.end(smallvec![inner_range]).end(smallvec![outer_range]);

    let result = flatten_range_impl(&nested);
    assert!(result.is_some(), "Nested END should be flattened");
    let merged = result.unwrap();

    match merged.op() {
        Op::End { ranges, .. } => {
            assert_eq!(ranges.len(), 2, "Should have 2 ranges after flattening")
        }
        _ => panic!("Expected END operation"),
    }
}
/// Three levels of END nesting must be flattened in one call into a single END with three ranges.
#[test]
fn test_flatten_range_deeply_nested() {
    use morok_ir::Op;
    use smallvec::smallvec;

    // Wrap the constant once per (extent, axis) pair, innermost first.
    let mut nested = UOp::native_const(1.0f32);
    for (extent, axis) in [(10, 0), (20, 1), (30, 2)] {
        nested = nested.end(smallvec![UOp::range(UOp::index_const(extent), axis)]);
    }

    let result = flatten_range_impl(&nested);
    assert!(result.is_some(), "Deeply nested END should be flattened");
    let merged = result.unwrap();

    match merged.op() {
        Op::End { ranges, .. } => {
            assert_eq!(ranges.len(), 3, "Should have 3 ranges after deep flattening")
        }
        _ => panic!("Expected END operation"),
    }
}
/// Flattening `END(END(a + b, [r1]), [r2])` must keep the original binary computation
/// underneath the merged END.
///
/// The previous version only checked the computation when it was *still* wrapped in an inner
/// END. If flattening unwrapped it, which is the expected outcome, the check was skipped
/// entirely. This version peels any END wrappers that remain and always asserts that a
/// `Binary` op is at the bottom.
#[test]
fn test_flatten_range_preserves_computation() {
    use morok_ir::Op;
    use smallvec::smallvec;

    let a = UOp::native_const(1.0f32);
    let b = UOp::native_const(2.0f32);
    let add = a.try_add(&b).unwrap();
    let r1 = UOp::range(UOp::index_const(10), 0);
    let r2 = UOp::range(UOp::index_const(20), 1);
    let inner_end = add.end(smallvec![r1]);
    let outer_end = inner_end.end(smallvec![r2]);

    let flattened = flatten_range_impl(&outer_end).expect("Nested END should be flattened");

    let (computation, ranges) = match flattened.op() {
        Op::End { computation, ranges } => (computation, ranges),
        _ => panic!("Expected END operation"),
    };
    assert_eq!(ranges.len(), 2);

    // Tolerate an implementation that still leaves END wrappers around the body; whatever
    // wrappers remain, the innermost computation must be the original binary op.
    let mut body = computation.clone();
    loop {
        let next = match body.op() {
            Op::End { computation, .. } => computation.clone(),
            _ => break,
        };
        body = next;
    }
    assert!(
        matches!(body.op(), Op::Binary(..)),
        "Flattening should preserve the binary computation"
    );
}
#[test]
fn test_flatten_ranges_full_graph() {
use morok_ir::Op;
use smallvec::smallvec;
let computation = UOp::native_const(1.0f32);
let r1 = UOp::range(UOp::index_const(10), 0);
let r2 = UOp::range(UOp::index_const(20), 1);
let inner_end = computation.clone().end(smallvec![r1.clone()]);
let outer_end = inner_end.end(smallvec![r2.clone()]);
let flattened = flatten_ranges(&outer_end);
assert!(!Arc::ptr_eq(&flattened, &outer_end), "Graph should be transformed");
assert!(matches!(flattened.op(), Op::End { .. }));
}
#[test]
fn test_flatten_range_single_range() {
use smallvec::smallvec;
let computation = UOp::native_const(1.0f32);
let r1 = UOp::range(UOp::index_const(10), 0);
let end = computation.clone().end(smallvec![r1.clone()]);
let flattened = flatten_range_impl(&end);
assert!(flattened.is_none());
}