use std::collections::HashSet;
use std::sync::Arc;
use morok_dtype::{AddrSpace, DType};
use morok_ir::types::{AxisId, AxisType};
use morok_ir::{Op, UOp};
/// Builds a `Local`-axis RANGE whose end bound is an `Index`-typed integer constant
/// `end_value`, tagged with the renumbered axis id `axis_id`.
fn create_local_range(end_value: i64, axis_id: usize) -> Arc<UOp> {
    let bound = morok_ir::types::ConstValue::Int(end_value);
    let end = UOp::const_(DType::Index, bound);
    let axis = AxisId::Renumbered(axis_id);
    UOp::range_axis(end, axis, AxisType::Local)
}
/// Builds a PARAM node with id `buf_id` typed as a `Float32` pointer in the `Global`
/// address space; both the param size argument and the pointer size are 1024.
fn create_global_buffer(buf_id: usize) -> Arc<UOp> {
    let ptr_dtype = DType::Float32.ptr(Some(1024), AddrSpace::Global);
    UOp::param(buf_id, 1024, ptr_dtype, None)
}
/// Builds an INDEX node over `buffer` with the given `indices`, panicking if the
/// builder rejects the combination (acceptable in test setup).
fn create_index(buffer: Arc<UOp>, indices: Vec<Arc<UOp>>) -> Arc<UOp> {
    let builder = UOp::index().buffer(buffer).indices(indices);
    builder.call().expect("index should succeed")
}
/// A RANGE used directly as an operand of `add` is reported in that node's in-scope ranges.
#[test]
fn test_in_scope_ranges_basic() {
    let loop_range = create_local_range(16, 0);
    let one = UOp::index_const(1);
    let sum = loop_range.add(&one);

    // clippy::mutable_key_type fires on the returned set's key type; the set is only read here.
    #[allow(clippy::mutable_key_type)]
    let scope = sum.in_scope_ranges();
    let is_open = scope.iter().any(|entry| entry.0.id == loop_range.id);
    assert!(is_open, "Range should be in scope: found {:?}", scope);
}
/// Closing a RANGE with `end` removes it from the END node's in-scope ranges.
#[test]
fn test_in_scope_ranges_after_end() {
    let loop_range = create_local_range(16, 0);
    let body = loop_range.add(&UOp::index_const(1));
    let closed = body.end(smallvec::smallvec![Arc::clone(&loop_range)]);

    // clippy::mutable_key_type fires on the returned set's key type; the set is only read here.
    #[allow(clippy::mutable_key_type)]
    let scope = closed.in_scope_ranges();
    let still_open = scope.iter().any(|entry| entry.0.id == loop_range.id);
    assert!(!still_open, "Range should NOT be in scope after END: found {:?}", scope);
}
/// Of two ranges, only the one closed by `end` drops out of scope. A later node that
/// depends on that END (via `after`) while still using the other range sees only the
/// other range.
#[test]
fn test_in_scope_ranges_partial_end() {
    let range_a = create_local_range(16, 0);
    let range_b = create_local_range(32, 1);

    // range_a is closed here; range_b is deliberately left open.
    let closes_a = range_a.add(&range_b).end(smallvec::smallvec![Arc::clone(&range_a)]);
    let uses_b = range_b.add(&UOp::index_const(5));
    let final_comp = uses_b.after(smallvec::smallvec![closes_a]);

    // clippy::mutable_key_type fires on the returned set's key type; the set is only read here.
    #[allow(clippy::mutable_key_type)]
    let in_scope = final_comp.in_scope_ranges();
    let has = |target: &Arc<UOp>| in_scope.iter().any(|key| key.0.id == target.id);
    assert!(!has(&range_a), "range1 should NOT be in scope");
    assert!(has(&range_b), "range2 should be in scope");
}
/// A RANGE closed by END is still reachable in the toposort of a downstream node,
/// but it must not appear in that node's in-scope ranges.
#[test]
fn test_toposort_vs_in_scope_difference() {
    let loop_range = create_local_range(16, 0);
    let closed = loop_range
        .add(&UOp::index_const(42))
        .end(smallvec::smallvec![Arc::clone(&loop_range)]);

    let index = create_index(create_global_buffer(0), vec![UOp::index_const(0)]);
    let final_graph = index.after(smallvec::smallvec![closed]);

    // Collect the ids of every RANGE node reachable in the toposort.
    let mut topo_ranges: HashSet<u64> = HashSet::new();
    for node in final_graph.toposort().iter() {
        if matches!(node.op(), Op::Range { .. }) {
            topo_ranges.insert(node.id);
        }
    }
    let scope_ids: HashSet<u64> =
        final_graph.in_scope_ranges().iter().map(|key| key.0.id).collect();

    assert!(topo_ranges.contains(&loop_range.id), "Range should be in toposort of final_graph");
    assert!(
        !scope_ids.contains(&loop_range.id),
        "Range should NOT be in final_graph's in_scope_ranges"
    );
}
/// An INDEX that uses an open RANGE as one of its indices has that RANGE in scope.
#[test]
fn test_index_in_scope_with_active_range() {
    let loop_range = create_local_range(16, 0);
    let index = create_index(create_global_buffer(0), vec![Arc::clone(&loop_range)]);

    let scope_ids = index
        .in_scope_ranges()
        .iter()
        .map(|key| key.0.id)
        .collect::<HashSet<u64>>();
    assert!(scope_ids.contains(&loop_range.id), "Active range should be in INDEX's scope");
}
/// A RANGE that exists and is consumed elsewhere is not in an INDEX's scope when the
/// INDEX is built only from a buffer and a constant index.
#[test]
fn test_index_scope_with_unused_but_active_range() {
    let loop_range = create_local_range(16, 0);
    let buffer = create_global_buffer(0);
    let zero = UOp::index_const(0);
    // Gives the range a consumer outside the INDEX graph; the INDEX below never references it.
    let _unrelated = loop_range.add(&UOp::index_const(1));
    let index = create_index(buffer, vec![zero]);

    let leaked = index.in_scope_ranges().iter().any(|key| key.0.id == loop_range.id);
    assert!(!leaked, "Unused range should NOT be in INDEX's scope");
}