use std::sync::Arc;
use morok_dtype::DType;
use morok_ir::{AxisType, ConstValue, Op, SInt, TernaryOp, UOp};
use crate::rangeify::indexing::IndexingContext;
/// Two consumers that index a 2-D producer with exactly the same ranges.
///
/// Checks both halves of the claim in the test name:
/// - per axis, the indices extracted from each consumer compare equal under
///   `all_ranges_same`;
/// - `merge_consumer_ranges` keeps the shared ranges and does not put the producer
///   into `ctx.realize_map`.
#[test]
fn test_identical_ranges_no_realization() {
    use crate::rangeify::indexing::all_ranges_same;
    use crate::rangeify::merge_consumer_ranges;

    let mut ctx = IndexingContext::new();
    let r0 = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let r1 = ctx.new_range(&SInt::Const(20), AxisType::Loop);
    let consumer_rngs = vec![vec![r0.clone(), r1.clone()], vec![r0.clone(), r1.clone()]];

    // Compare the same axis across every consumer.
    for axis in 0..2 {
        let indices: Vec<_> = consumer_rngs.iter().map(|rngs| rngs[axis].get_idx()).collect();
        assert!(
            all_ranges_same(&indices),
            "axis {}: identical ranges should compare equal",
            axis
        );
    }

    // The realization half of the claim: the shape matches the ranges (10 x 20), so
    // merging identical consumer ranges must neither replace them nor mark the
    // producer for realization.
    let reshaped = create_reshaped_2d(&[10, 20]);
    let merged = merge_consumer_ranges(&reshaped, &consumer_rngs, &mut ctx).unwrap();
    assert_eq!(merged.len(), 2, "Should have 2 merged ranges");
    assert!(Arc::ptr_eq(&merged[0], &r0), "Axis 0 range should be unchanged");
    assert!(Arc::ptr_eq(&merged[1], &r1), "Axis 1 range should be unchanged");
    assert!(
        !ctx.realize_map.contains_key(&morok_ir::UOpKey(reshaped.clone())),
        "Identical ranges should NOT require realization"
    );
}
/// A range created directly by the context has no validity wrapper, so `get_idx`
/// must return the very same node (pointer-identical), not a rebuilt copy.
#[test]
fn test_get_idx_plain_range() {
    let mut ctx = IndexingContext::new();
    let plain = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    assert!(Arc::ptr_eq(&plain.get_idx(), &plain));
}
/// The validity of a range with no wrapper is the boolean constant `true`.
#[test]
fn test_get_valid_plain_range() {
    let mut ctx = IndexingContext::new();
    let plain = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let validity = plain.get_valid();
    match validity.op() {
        Op::Const(cv) => assert_eq!(cv.0, ConstValue::Bool(true)),
        _ => panic!("Expected constant true for plain range validity"),
    }
}
/// Unwrapping `WHERE(true, idx, INVALID)` with `get_idx` yields `idx` itself.
#[test]
fn test_get_idx_with_validity() {
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let guarded = UOp::try_where(
        UOp::native_const(true),
        range.clone(),
        UOp::invalid_marker(),
    )
    .unwrap();
    assert!(Arc::ptr_eq(&guarded.get_idx(), &range));
}
/// `get_valid` on `WHERE(idx < 5, idx, INVALID)` returns the exact condition node.
#[test]
fn test_get_valid_with_validity() {
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let cond = range.try_cmplt(&UOp::index_const(5)).unwrap();
    let guarded = UOp::try_where(cond.clone(), range.clone(), UOp::invalid_marker()).unwrap();
    assert!(Arc::ptr_eq(&guarded.get_valid(), &cond));
}
/// A range compared against a clone of its own `Arc` counts as "all the same".
#[test]
fn test_all_ranges_same_identical() {
    use crate::rangeify::indexing::all_ranges_same;

    let mut ctx = IndexingContext::new();
    let first = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let second = Arc::clone(&first);
    assert!(all_ranges_same(&[first, second]));
}
/// Indices of two separately created ranges (extents 10 and 20) are not "all the same".
#[test]
fn test_all_ranges_same_different() {
    use crate::rangeify::indexing::all_ranges_same;

    let mut ctx = IndexingContext::new();
    let indices = [
        ctx.new_range(&SInt::Const(10), AxisType::Loop).get_idx(),
        ctx.new_range(&SInt::Const(20), AxisType::Loop).get_idx(),
    ];
    assert!(!all_ranges_same(&indices));
}
/// `invalid_marker` builds an `Op::Invalid` node whose dtype is `DType::Index`.
#[test]
fn test_invalid_marker_detection() {
    let marker = UOp::invalid_marker();
    let is_invalid_op = matches!(marker.op(), Op::Invalid);
    assert!(is_invalid_op);
    assert_eq!(marker.dtype(), DType::Index);
}
/// A padded index built as `WHERE(valid, idx, INVALID)` carries the invalid marker
/// in its false arm.
#[test]
fn test_padding_uses_invalid_marker() {
    let padded = UOp::try_where(
        UOp::native_const(true),
        UOp::index_const(0),
        UOp::invalid_marker(),
    )
    .unwrap();
    match padded.op() {
        Op::Ternary(TernaryOp::Where, _, _, else_arm) => {
            assert!(matches!(else_arm.op(), Op::Invalid));
        }
        _ => panic!("Expected WHERE operation"),
    }
}
/// Combining two validity masks (`idx < 5`, `idx < 8`) with `try_or_op` produces a
/// binary `Or` node.
#[test]
fn test_or_merging_of_validity_masks() {
    let mut ctx = IndexingContext::new();
    let range = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let below_five = range.try_cmplt(&UOp::index_const(5)).unwrap();
    let below_eight = range.try_cmplt(&UOp::index_const(8)).unwrap();
    let merged = below_five.try_or_op(&below_eight).unwrap();
    match merged.op() {
        Op::Binary(op, _, _) => assert!(matches!(op, morok_ir::BinaryOp::Or)),
        _ => panic!("Expected OR operation"),
    }
}
/// An empty list of ranges is treated as "all the same".
#[test]
fn test_empty_ranges_list() {
    use crate::rangeify::indexing::all_ranges_same;

    let no_ranges: Vec<Arc<UOp>> = Vec::new();
    assert!(all_ranges_same(&no_ranges));
}
/// A one-element list of ranges is treated as "all the same".
#[test]
fn test_single_range() {
    use crate::rangeify::indexing::all_ranges_same;

    let mut ctx = IndexingContext::new();
    let only = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    assert!(all_ranges_same(std::slice::from_ref(&only)));
}
/// Test fixture: a CPU `Float32` buffer node of the given `size`.
fn create_buffer_with_size(size: usize) -> Arc<UOp> {
    let device = morok_dtype::DeviceSpec::Cpu;
    UOp::new_buffer(device, size, DType::Float32)
}
/// Test fixture: a `Float32` `Reshape` node that views a fresh buffer of
/// `product(sizes)` elements with shape `sizes`.
///
/// Despite the name, the slice may hold any number of dimensions; the shape is
/// encoded as a vectorized list of index constants.
fn create_reshaped_2d(sizes: &[usize]) -> Arc<UOp> {
    let element_count: usize = sizes.iter().product();
    let src = create_buffer_with_size(element_count);
    let new_shape = UOp::vectorize(
        sizes
            .iter()
            .map(|&dim| UOp::index_const(dim as i64))
            .collect(),
    );
    UOp::new(Op::Reshape { src, new_shape }, DType::Float32)
}
/// When both consumers index a 1-D buffer with the same range, merging returns that
/// range unchanged and the buffer is not entered into `realize_map`.
#[test]
fn test_merge_consumer_ranges_identical_1d() {
    use crate::rangeify::merge_consumer_ranges;

    let mut ctx = IndexingContext::new();
    let buffer = create_buffer_with_size(100);
    let shared = ctx.new_range(&SInt::Const(100), AxisType::Loop);
    // Two consumers, each holding a handle to the same range.
    let consumer_rngs = vec![vec![Arc::clone(&shared)]; 2];

    let merged = merge_consumer_ranges(&buffer, &consumer_rngs, &mut ctx).unwrap();
    assert_eq!(merged.len(), 1, "Should have 1 merged range");
    assert!(Arc::ptr_eq(&merged[0], &shared), "Range should be unchanged");

    let key = morok_ir::UOpKey(buffer.clone());
    assert!(
        !ctx.realize_map.contains_key(&key),
        "Identical ranges should NOT require realization"
    );
}
/// Two distinct ranges of the same extent cause merging to produce a range that is
/// neither input, and the buffer is entered into `realize_map`.
#[test]
fn test_merge_consumer_ranges_different_1d() {
    use crate::rangeify::merge_consumer_ranges;

    let mut ctx = IndexingContext::new();
    let buffer = create_buffer_with_size(100);
    let first = ctx.new_range(&SInt::Const(100), AxisType::Loop);
    let second = ctx.new_range(&SInt::Const(100), AxisType::Loop);
    let consumer_rngs = vec![vec![first.clone()], vec![second.clone()]];

    let merged = merge_consumer_ranges(&buffer, &consumer_rngs, &mut ctx).unwrap();
    assert_eq!(merged.len(), 1, "Should have 1 merged range");
    for input in [&first, &second] {
        assert!(!Arc::ptr_eq(&merged[0], input), "Different ranges should create new range");
    }

    let key = morok_ir::UOpKey(buffer.clone());
    assert!(
        ctx.realize_map.contains_key(&key),
        "Different ranges should require realization"
    );
}
/// A 2-D producer whose two consumers share the axis-0 range but differ on axis 1.
///
/// Per the assertion messages below, when the consumers do not agree on every axis,
/// every axis gets a new range and both axes are recorded for realization.
#[test]
fn test_merge_consumer_ranges_2d_partial_overlap() {
    use crate::rangeify::merge_consumer_ranges;

    let mut ctx = IndexingContext::new();
    let reshaped = create_reshaped_2d(&[10, 20]);
    let rows = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let cols_a = ctx.new_range(&SInt::Const(20), AxisType::Loop);
    let cols_b = ctx.new_range(&SInt::Const(20), AxisType::Loop);
    let consumer_rngs = vec![
        vec![rows.clone(), cols_a.clone()],
        vec![rows.clone(), cols_b.clone()],
    ];

    let merged = merge_consumer_ranges(&reshaped, &consumer_rngs, &mut ctx).unwrap();
    assert_eq!(merged.len(), 2, "Should have 2 merged ranges");
    assert!(!Arc::ptr_eq(&merged[0], &rows), "All dims realized when all_all_same=false");
    assert!(
        !Arc::ptr_eq(&merged[1], &cols_a),
        "Different second dimension should create new range"
    );

    let realize_info = ctx.realize_map.get(&morok_ir::UOpKey(reshaped.clone()));
    assert!(realize_info.is_some(), "Should mark for realization");
    // NOTE(review): an entry of `Some(None)` skips the axis check entirely — confirm
    // whether that value is legitimate here or whether this should be a hard assert.
    if let Some(Some(axes)) = realize_info {
        assert_eq!(axes, &[0, 1], "Both dimensions should need realization");
    }
}
/// The two consumers share the same index but carry different validity masks
/// (`idx < 5` and `idx < 8`).
///
/// Merging them must yield a single `WHERE` whose index is the original `idx` node
/// and whose condition is a binary `Or`.
#[test]
fn test_merge_consumer_ranges_with_validity() {
    use crate::rangeify::merge_consumer_ranges;

    let mut ctx = IndexingContext::new();
    let buffer = create_buffer_with_size(10);
    let idx = ctx.new_range(&SInt::Const(10), AxisType::Loop);
    let invalid = UOp::invalid_marker();
    // Builds WHERE(idx < bound, idx, INVALID).
    let guarded = |bound: i64| {
        let cond = idx.try_cmplt(&UOp::index_const(bound)).unwrap();
        UOp::try_where(cond, idx.clone(), invalid.clone()).unwrap()
    };
    let consumer_rngs = vec![vec![guarded(5)], vec![guarded(8)]];

    let merged = merge_consumer_ranges(&buffer, &consumer_rngs, &mut ctx).unwrap();
    assert_eq!(merged.len(), 1, "Should have 1 merged range");

    match merged[0].op() {
        Op::Ternary(TernaryOp::Where, merged_valid, merged_idx, _) => {
            assert!(Arc::ptr_eq(merged_idx, &idx), "Merged index should be unchanged");
            match merged_valid.op() {
                Op::Binary(op, _, _) => {
                    assert!(matches!(op, morok_ir::BinaryOp::Or), "Validity should be OR'd");
                }
                other => panic!("Expected OR operation in merged validity, got {:?}", other),
            }
        }
        other => panic!("Expected WHERE operation in merged range, got {:?}", other),
    }
}
/// With no consumers at all, merging still returns one range for the 1-D buffer and
/// enters the buffer into `realize_map`.
#[test]
fn test_merge_consumer_ranges_empty() {
    use crate::rangeify::merge_consumer_ranges;

    let mut ctx = IndexingContext::new();
    let buffer = create_buffer_with_size(10);
    let no_consumers: Vec<Vec<Arc<UOp>>> = Vec::new();

    let merged = merge_consumer_ranges(&buffer, &no_consumers, &mut ctx).unwrap();
    assert_eq!(merged.len(), 1, "Should create 1 range for 1-dim buffer");
    assert!(
        ctx.realize_map.contains_key(&morok_ir::UOpKey(buffer.clone())),
        "Should mark for realization"
    );
}