use super::*;
/// A multiply tagged with an explicit `Rn` rounding mode followed by an add
/// built through `make_add` must leave `pass` as a single instruction.
#[test]
fn test_rounding_mode_rn_matches_none() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);
    let addend = make_vreg(3, PtxType::F32);
    let sum = make_vreg(4, PtxType::F32);

    // Only the multiply carries an explicit rounding mode.
    let mul_rn = PtxInstruction::new(PtxOp::Mul, PtxType::F32)
        .dst(Operand::Reg(product))
        .src(Operand::Reg(lhs))
        .src(Operand::Reg(rhs))
        .rounding(RoundingMode::Rn);
    let plain_add = make_add(sum, product, addend);

    let output = pass(vec![mul_rn, plain_add]);
    assert_eq!(output.len(), 1, "Explicit Rn should be compatible with None (default)");
}
/// Mirror of the `Rn`-then-default case: the multiply comes from `make_mul`
/// and only the add carries an explicit `Rn`. `pass` must still return a
/// single instruction.
#[test]
fn test_rounding_mode_none_matches_rn() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);
    let addend = make_vreg(3, PtxType::F32);
    let sum = make_vreg(4, PtxType::F32);

    let plain_mul = make_mul(product, lhs, rhs);
    // Only the add carries an explicit rounding mode.
    let add_rn = PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .dst(Operand::Reg(sum))
        .src(Operand::Reg(product))
        .src(Operand::Reg(addend))
        .rounding(RoundingMode::Rn);

    let output = pass(vec![plain_mul, add_rn]);
    assert_eq!(output.len(), 1, "None should be compatible with explicit Rn");
}
/// When the multiply and the add both carry the same explicit mode (`Rz`),
/// `pass` must return a single instruction.
#[test]
fn test_rounding_mode_same_explicit_mode() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);
    let addend = make_vreg(3, PtxType::F32);
    let sum = make_vreg(4, PtxType::F32);

    let mul_rz = PtxInstruction::new(PtxOp::Mul, PtxType::F32)
        .dst(Operand::Reg(product))
        .src(Operand::Reg(lhs))
        .src(Operand::Reg(rhs))
        .rounding(RoundingMode::Rz);
    let add_rz = PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .dst(Operand::Reg(sum))
        .src(Operand::Reg(product))
        .src(Operand::Reg(addend))
        .rounding(RoundingMode::Rz);

    let output = pass(vec![mul_rz, add_rz]);
    assert_eq!(output.len(), 1, "Same explicit rounding modes should be compatible");
}
/// A multiply rounded with `Rn` feeding an add rounded with `Rz` must not be
/// merged: `pass` is expected to hand back both instructions.
#[test]
fn test_incompatible_rounding_modes_prevent_fusion() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);
    let addend = make_vreg(3, PtxType::F32);
    let sum = make_vreg(4, PtxType::F32);

    let mul_rn = PtxInstruction::new(PtxOp::Mul, PtxType::F32)
        .dst(Operand::Reg(product))
        .src(Operand::Reg(lhs))
        .src(Operand::Reg(rhs))
        .rounding(RoundingMode::Rn);
    let add_rz = PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .dst(Operand::Reg(sum))
        .src(Operand::Reg(product))
        .src(Operand::Reg(addend))
        .rounding(RoundingMode::Rz);

    let output = pass(vec![mul_rn, add_rz]);
    assert_eq!(output.len(), 2, "Incompatible rounding modes should prevent fusion");
}
/// A multiply with an explicit `Rz` followed by an add built through
/// `make_add` must come back from `pass` as two instructions.
#[test]
fn test_non_rn_with_none_incompatible() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);
    let addend = make_vreg(3, PtxType::F32);
    let sum = make_vreg(4, PtxType::F32);

    let mul_rz = PtxInstruction::new(PtxOp::Mul, PtxType::F32)
        .dst(Operand::Reg(product))
        .src(Operand::Reg(lhs))
        .src(Operand::Reg(rhs))
        .rounding(RoundingMode::Rz);
    let plain_add = make_add(sum, product, addend);

    let output = pass(vec![mul_rz, plain_add]);
    assert_eq!(output.len(), 2, "Non-Rn mode with implicit None should be incompatible");
}
/// Exercises `rounding_modes_compatible` directly over the combinations of
/// absent and explicit rounding modes checked below.
#[test]
fn test_rounding_modes_compatible_direct() {
    let rn = RoundingMode::Rn;
    let rz = RoundingMode::Rz;
    let rp = RoundingMode::Rp;
    let rm = RoundingMode::Rm;

    // Absent on both sides, or absent paired with `Rn`.
    assert!(rounding_modes_compatible(None, None));
    assert!(rounding_modes_compatible(None, Some(&rn)));
    assert!(rounding_modes_compatible(Some(&rn), None));

    // Identical explicit modes.
    assert!(rounding_modes_compatible(Some(&rn), Some(&rn)));
    assert!(rounding_modes_compatible(Some(&rz), Some(&rz)));
    assert!(rounding_modes_compatible(Some(&rp), Some(&rp)));
    assert!(rounding_modes_compatible(Some(&rm), Some(&rm)));

    // Differing explicit modes.
    assert!(!rounding_modes_compatible(Some(&rn), Some(&rz)));
    assert!(!rounding_modes_compatible(Some(&rp), Some(&rm)));

    // `Rz` paired with an absent mode, in either position.
    assert!(!rounding_modes_compatible(Some(&rz), None));
    assert!(!rounding_modes_compatible(None, Some(&rz)));
}
/// Checks the shape of the instruction produced from a `make_mul` /
/// `make_add` pair: opcode, type, source count, destination register, and
/// the presence of a rounding mode.
#[test]
fn test_fma_instruction_structure() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);
    let addend = make_vreg(3, PtxType::F32);
    let sum = make_vreg(4, PtxType::F32);

    let output = pass(vec![make_mul(product, lhs, rhs), make_add(sum, product, addend)]);
    assert_eq!(output.len(), 1);

    let fused = &output[0];
    assert!(matches!(fused.op, PtxOp::Fma));
    assert_eq!(fused.ty, PtxType::F32);
    assert_eq!(fused.srcs.len(), 3);

    // The surviving instruction must write to the add's destination register.
    if let Some(Operand::Reg(written)) = &fused.dst {
        assert_eq!(*written, sum);
    } else {
        panic!("Expected register destination");
    }

    assert!(fused.rounding.is_some());
}
/// An add built without any destination operand must be left alone: `pass`
/// returns both input instructions.
#[test]
fn test_add_no_destination() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);

    let mul = make_mul(product, lhs, rhs);
    // Deliberately no `.dst(...)` on the add.
    let add_without_dst = PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .src(Operand::Reg(product))
        .src(Operand::Reg(lhs));

    let output = pass(vec![mul, add_without_dst]);
    assert_eq!(output.len(), 2);
}
/// `count_register_uses` must report two uses for a register that appears as
/// both sources of an add, with an immediate-sourced `Mov` also in the list.
#[test]
fn test_count_uses_with_immediate_sources() {
    let loaded = make_vreg(0, PtxType::F32);
    let doubled = make_vreg(1, PtxType::F32);

    let load_imm = PtxInstruction::new(PtxOp::Mov, PtxType::F32)
        .dst(Operand::Reg(loaded))
        .src(Operand::ImmF32(1.0));
    // `loaded` is read twice here.
    let self_add = make_add(doubled, loaded, loaded);

    let program = vec![load_imm, self_add];
    let use_counts = count_register_uses(&program);
    assert_eq!(use_counts.get(&loaded), Some(&2));
}
/// `build_def_map` must record only the `Mov` destination (at index 0) when
/// the second instruction is built without a destination operand.
#[test]
fn test_build_def_map_no_reg_dst() {
    let target = make_vreg(0, PtxType::F32);

    let program = vec![
        PtxInstruction::new(PtxOp::Mov, PtxType::F32)
            .dst(Operand::Reg(target))
            .src(Operand::ImmF32(1.0)),
        // No destination is attached to this instruction.
        PtxInstruction::new(PtxOp::Bra, PtxType::Pred),
    ];

    let def_sites = build_def_map(&program);
    assert_eq!(def_sites.get(&target), Some(&0));
    assert_eq!(def_sites.len(), 1);
}
/// An add that has only one source operand must not be merged with the
/// preceding multiply: `pass` returns both instructions.
#[test]
fn test_add_single_source() {
    let input = make_vreg(0, PtxType::F32);
    let squared = make_vreg(1, PtxType::F32);
    let out = make_vreg(2, PtxType::F32);

    let square = make_mul(squared, input, input);
    // Only one `.src(...)` is attached to the add.
    let unary_add = PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .dst(Operand::Reg(out))
        .src(Operand::Reg(squared));

    let output = pass(vec![square, unary_add]);
    assert_eq!(output.len(), 2);
}
/// Two independent multiply/add pairs, with both multiplies listed before
/// both adds, must each collapse so that `pass` returns two `Fma`
/// instructions.
#[test]
fn test_multiple_fusions() {
    // First pair: first_prod = first_a * first_b; first_sum = first_prod + first_a.
    let first_a = make_vreg(0, PtxType::F32);
    let first_b = make_vreg(1, PtxType::F32);
    let first_prod = make_vreg(2, PtxType::F32);
    // Second pair: second_prod = second_a * second_b; second_sum = second_prod + second_a.
    let second_a = make_vreg(3, PtxType::F32);
    let second_b = make_vreg(4, PtxType::F32);
    let second_prod = make_vreg(5, PtxType::F32);
    let first_sum = make_vreg(6, PtxType::F32);
    let second_sum = make_vreg(7, PtxType::F32);

    let program = vec![
        make_mul(first_prod, first_a, first_b),
        make_mul(second_prod, second_a, second_b),
        make_add(first_sum, first_prod, first_a),
        make_add(second_sum, second_prod, second_a),
    ];

    let output = pass(program);
    assert_eq!(output.len(), 2);
    // The length check above guarantees this visits exactly indices 0 and 1.
    for instr in &output {
        assert!(matches!(instr.op, PtxOp::Fma));
    }
}
/// When both the multiply and the add are rounded with `Rz`, the single
/// instruction returned by `pass` must carry `Rz` as its rounding mode.
#[test]
fn test_fma_inherits_mul_rounding() {
    let lhs = make_vreg(0, PtxType::F32);
    let rhs = make_vreg(1, PtxType::F32);
    let product = make_vreg(2, PtxType::F32);
    let addend = make_vreg(3, PtxType::F32);
    let sum = make_vreg(4, PtxType::F32);

    let mul_rz = PtxInstruction::new(PtxOp::Mul, PtxType::F32)
        .dst(Operand::Reg(product))
        .src(Operand::Reg(lhs))
        .src(Operand::Reg(rhs))
        .rounding(RoundingMode::Rz);
    let add_rz = PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .dst(Operand::Reg(sum))
        .src(Operand::Reg(product))
        .src(Operand::Reg(addend))
        .rounding(RoundingMode::Rz);

    let output = pass(vec![mul_rz, add_rz]);
    assert_eq!(output.len(), 1);
    assert_eq!(output[0].rounding, Some(RoundingMode::Rz));
}
/// With only a special-register source and an immediate source in the
/// program, `count_register_uses` must return an empty result.
#[test]
fn test_special_register_in_sources() {
    use super::super::super::super::registers::PtxReg;

    let float_dst = make_vreg(0, PtxType::F32);
    let tid_dst = make_vreg(1, PtxType::U32);

    let read_tid = PtxInstruction::new(PtxOp::Mov, PtxType::U32)
        .dst(Operand::Reg(tid_dst))
        .src(Operand::SpecialReg(PtxReg::TidX));
    let load_imm = PtxInstruction::new(PtxOp::Mov, PtxType::F32)
        .dst(Operand::Reg(float_dst))
        .src(Operand::ImmF32(1.0));

    let program = vec![read_tid, load_imm];
    let use_counts = count_register_uses(&program);
    assert_eq!(use_counts.len(), 0);
}