use super::*;
use crate::ptx::types::PtxType;
/// A single heavy op (a load) is enough for `is_split_profitable` to accept a
/// split at threshold 10; a single light op (an add) is not, but ten of them are.
#[test]
fn test_profitability_with_heavy_ops() {
    let heavy = PtxInstruction::new(PtxOp::Ld, PtxType::F32);
    let light = PtxInstruction::new(PtxOp::Add, PtxType::F32);

    // Ten identical light instructions; the first one alone doubles as the
    // single-light-op case.
    let ten_light = vec![light; 10];

    assert!(is_split_profitable(&[heavy], 10));
    assert!(!is_split_profitable(&ten_light[..1], 10));
    assert!(is_split_profitable(&ten_light, 10));
}
/// Checks `align_split_point` on a table of cases. They are consistent with
/// rounding the first argument up to the next value of the form
/// `second + k * third`; a third argument of 1 leaves the point unchanged.
#[test]
fn test_split_point_alignment() {
    // (point, base, step, expected); argument roles as suggested by the cases.
    let cases = [
        (5, 0, 1, 5),
        (5, 0, 4, 8),
        (8, 0, 4, 8),
        (10, 2, 4, 10),
        (9, 2, 4, 10),
    ];
    for (point, base, step, expected) in cases {
        assert_eq!(align_split_point(point, base, step), expected);
    }
}
/// Boundary behaviour of `align_split_point` when the split point is zero.
#[test]
fn test_split_handles_boundary() {
    // Zero point with a zero second argument stays at zero.
    let at_origin = align_split_point(0, 0, 4);
    assert_eq!(at_origin, 0);

    // Zero point with a larger second argument is moved up to that value.
    let below_base = align_split_point(0, 5, 4);
    assert_eq!(below_base, 5);
}
/// Memory and WMMA ops are classified as heavy; plain arithmetic is not.
#[test]
fn test_is_heavy_op() {
    for op in [PtxOp::Ld, PtxOp::St, PtxOp::WmmaMma] {
        assert!(is_heavy_op(&op));
    }
    for op in [PtxOp::Add, PtxOp::Mul] {
        assert!(!is_heavy_op(&op));
    }
}
/// Running `analyze` twice on the same input with the same config must give
/// results that `is_idempotent` considers equivalent.
#[test]
fn test_loop_split_idempotent() {
    let config = LoopSplitConfig::default();
    let body = [
        PtxInstruction::new(PtxOp::Setp, PtxType::Pred),
        PtxInstruction::new(PtxOp::Ld, PtxType::F32),
    ];

    let (first, second) = (analyze(&body, &config), analyze(&body, &config));
    assert!(is_idempotent(&first, &second));
}
/// `then_is_second` is false for the less-than family and true for the
/// greater-than family.
#[test]
fn test_loop_predicate_then_is_second() {
    let expectations = [
        (LoopPredicate::LessThan, false),
        (LoopPredicate::LessEqual, false),
        (LoopPredicate::GreaterThan, true),
        (LoopPredicate::GreaterEqual, true),
    ];
    for (predicate, then_second) in expectations {
        assert_eq!(predicate.then_is_second(), then_second);
    }
}
/// An empty instruction slice yields no split candidates.
#[test]
fn test_analyze_empty() {
    let config = LoopSplitConfig::default();
    let candidates = analyze(&[], &config);
    assert!(candidates.is_empty());
}
/// Without any `setp` instruction there is nothing to split on, so `analyze`
/// reports no candidates.
#[test]
fn test_analyze_no_setp() {
    let config = LoopSplitConfig::default();
    let arithmetic_only = vec![
        PtxInstruction::new(PtxOp::Add, PtxType::F32),
        PtxInstruction::new(PtxOp::Mul, PtxType::F32),
    ];
    assert!(analyze(&arithmetic_only, &config).is_empty());
}
/// The default configuration uses a threshold of 1.
#[test]
fn test_config_default() {
    let LoopSplitConfig { threshold, .. } = LoopSplitConfig::default();
    assert_eq!(threshold, 1);
}
/// Ordering comparisons map to a `LoopPredicate`; equality and inequality
/// have no loop-predicate form and map to `None`.
#[test]
fn test_loop_predicate_from_cmp_op() {
    let table = [
        (CmpOp::Lt, Some(LoopPredicate::LessThan)),
        (CmpOp::Le, Some(LoopPredicate::LessEqual)),
        (CmpOp::Gt, Some(LoopPredicate::GreaterThan)),
        (CmpOp::Ge, Some(LoopPredicate::GreaterEqual)),
        (CmpOp::Eq, None),
        (CmpOp::Ne, None),
    ];
    for (op, expected) in table {
        assert_eq!(LoopPredicate::from_cmp_op(op), expected);
    }
}
/// With the induction variable as the first `setp` source, the comparison
/// normalizes to `LessThan`.
///
/// NOTE(review): the instruction's comparison op is never set explicitly here;
/// presumably `PtxInstruction::new` defaults it to `lt` — confirm.
#[test]
fn test_normalize_comparison_lhs_induction_var() {
    let iv = VirtualReg::new(0, PtxType::U32);
    let bound_reg = VirtualReg::new(1, PtxType::U32);
    // `iv` is cloned because it is passed by value to `normalize_comparison`
    // below; `bound_reg` is not used again, so it can be moved in directly.
    let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
        .src(Operand::Reg(iv.clone()))
        .src(Operand::Reg(bound_reg));

    let (pred, _bound) = normalize_comparison(&cmp, iv)
        .expect("setp with the induction variable on the LHS should normalize");
    assert_eq!(pred, LoopPredicate::LessThan);
}
/// With the induction variable as the second `setp` source, the comparison is
/// mirrored and normalizes to `GreaterThan`.
///
/// NOTE(review): the instruction's comparison op is never set explicitly here;
/// presumably `PtxInstruction::new` defaults it to `lt` — confirm.
#[test]
fn test_normalize_comparison_rhs_induction_var() {
    let iv = VirtualReg::new(0, PtxType::U32);
    let bound_reg = VirtualReg::new(1, PtxType::U32);
    // `bound_reg` is not used again, so it is moved in; `iv` is cloned because
    // it is passed by value to `normalize_comparison` below.
    let cmp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
        .src(Operand::Reg(bound_reg))
        .src(Operand::Reg(iv.clone()));

    let (pred, _bound) = normalize_comparison(&cmp, iv)
        .expect("setp with the induction variable on the RHS should normalize");
    assert_eq!(pred, LoopPredicate::GreaterThan);
}
/// Only `setp` instructions can be normalized; an `add` is rejected even when
/// it reads the induction variable.
#[test]
fn test_normalize_comparison_not_setp() {
    let induction = VirtualReg::new(0, PtxType::U32);
    let add = PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .src(Operand::Reg(induction.clone()))
        .src(Operand::ImmU64(10));
    assert!(normalize_comparison(&add, induction).is_none());
}
/// A `setp` with a single source operand cannot be normalized.
#[test]
fn test_normalize_comparison_too_few_sources() {
    let induction = VirtualReg::new(0, PtxType::U32);
    let lone_operand = Operand::Reg(induction.clone());
    let setp = PtxInstruction::new(PtxOp::Setp, PtxType::Pred).src(lone_operand);
    assert!(normalize_comparison(&setp, induction).is_none());
}
/// A `setp` comparing two registers, neither of which is the induction
/// variable, cannot be normalized.
#[test]
fn test_normalize_comparison_no_induction_var() {
    let induction = VirtualReg::new(0, PtxType::U32);
    let (lhs, rhs) = (
        VirtualReg::new(1, PtxType::U32),
        VirtualReg::new(2, PtxType::U32),
    );
    let unrelated = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
        .src(Operand::Reg(lhs))
        .src(Operand::Reg(rhs));
    assert!(normalize_comparison(&unrelated, induction).is_none());
}
/// A `setp` whose sources are both immediates does not involve the induction
/// variable and cannot be normalized.
#[test]
fn test_normalize_comparison_imm_operands() {
    let induction = VirtualReg::new(0, PtxType::U32);
    let both_constant = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
        .src(Operand::ImmU64(5))
        .src(Operand::ImmU64(10));
    let normalized = normalize_comparison(&both_constant, induction);
    assert!(normalized.is_none());
}
/// A `setp` that writes a predicate register, followed by a heavy load,
/// produces at least one candidate under the default configuration.
#[test]
fn test_analyze_with_setp_and_dst() {
    let compare = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
        .dst(Operand::Reg(VirtualReg::new(0, PtxType::Pred)))
        .src(Operand::ImmU64(0))
        .src(Operand::ImmU64(10));
    let load = PtxInstruction::new(PtxOp::Ld, PtxType::F32);

    let body = vec![compare, load];
    let config = LoopSplitConfig::default();
    let found = analyze(&body, &config);
    assert!(!found.is_empty());
}
/// Analysis with a threshold far above the body's cost must run cleanly and be
/// deterministic.
///
/// The previous assertion (`is_empty() || !is_empty()`) was a tautology that
/// could never fail. This version at least checks that two identical runs
/// report the same number of candidates.
///
/// TODO(review): pin the exact expected outcome (likely empty, since a single
/// light op is not profitable per `test_profitability_with_heavy_ops`) once
/// it is confirmed how `analyze` applies `config.threshold`.
#[test]
fn test_analyze_with_high_threshold() {
    let pred_reg = VirtualReg::new(0, PtxType::Pred);
    let setp_instr = PtxInstruction::new(PtxOp::Setp, PtxType::Pred)
        .dst(Operand::Reg(pred_reg))
        .src(Operand::ImmU64(0))
        .src(Operand::ImmU64(10));
    let light_instr = PtxInstruction::new(PtxOp::Add, PtxType::F32);
    let instructions = vec![setp_instr, light_instr];
    let config = LoopSplitConfig { threshold: 100 };

    let result = analyze(&instructions, &config);
    let rerun = analyze(&instructions, &config);
    assert_eq!(
        result.len(),
        rerun.len(),
        "analysis with a high threshold must be deterministic"
    );
}
/// A `SplittableCondition` built by struct literal exposes the values it was
/// constructed with.
#[test]
fn test_splittable_condition_fields() {
    let cond = SplittableCondition {
        cmp_idx: 5,
        induction_var: VirtualReg::new(0, PtxType::U32),
        predicate: LoopPredicate::LessThan,
        bound: Operand::ImmU64(100),
        if_ops: HashSet::new(),
    };
    assert_eq!(cond.cmp_idx, 5);
    assert_eq!(cond.predicate, LoopPredicate::LessThan);
}
/// A step of zero leaves the split point unchanged.
///
/// NOTE(review): presumably this guards against a modulo/division by zero
/// inside `align_split_point` — confirm.
#[test]
fn test_align_split_point_zero_step() {
    let unchanged = align_split_point(10, 0, 0);
    assert_eq!(unchanged, 10);
}
/// Every WMMA fragment load/store variant is classified as heavy.
#[test]
fn test_all_heavy_ops() {
    let wmma_memory_ops = [
        PtxOp::WmmaLoadA,
        PtxOp::WmmaLoadB,
        PtxOp::WmmaLoadC,
        PtxOp::WmmaStoreD,
    ];
    for op in wmma_memory_ops {
        assert!(is_heavy_op(&op));
    }
}