use super::*;
/// Builds an `F32` `Mul` instruction that writes `dst` and reads `a` then `b`.
fn make_mul(dst: VirtualReg, a: VirtualReg, b: VirtualReg) -> PtxInstruction {
    let lhs = Operand::Reg(a);
    let rhs = Operand::Reg(b);
    PtxInstruction::new(PtxOp::Mul, PtxType::F32)
        .dst(Operand::Reg(dst))
        .src(lhs)
        .src(rhs)
}
/// Builds an `F32` `Add` instruction that writes `dst` and reads `a` then `b`.
fn make_add(dst: VirtualReg, a: VirtualReg, b: VirtualReg) -> PtxInstruction {
    let lhs = Operand::Reg(a);
    let rhs = Operand::Reg(b);
    PtxInstruction::new(PtxOp::Add, PtxType::F32)
        .dst(Operand::Reg(dst))
        .src(lhs)
        .src(rhs)
}
/// Shorthand for `VirtualReg::new`: a virtual register with numeric `id` and type `ty`.
fn make_vreg(id: u32, ty: PtxType) -> VirtualReg {
    VirtualReg::new(id, ty)
}
/// A single-use `mul` feeding an `add` collapses into one `Fma` instruction.
#[test]
fn test_fma_reduces_instruction_count() {
    let [r0, r1, r2, r3, r4] = [0, 1, 2, 3, 4].map(|id| make_vreg(id, PtxType::F32));
    let result = pass(vec![make_mul(r2, r0, r1), make_add(r4, r2, r3)]);
    assert_eq!(result.len(), 1, "FMA fusion should reduce 2 instructions to 1");
    assert!(matches!(result[0].op, PtxOp::Fma), "Result should be FMA instruction");
}
/// A `mul` whose result is read by two `add`s must not be fused at all.
///
/// Fusing only one consumer would also be wrong: the FMA rounds once while the
/// remaining `mul`+`add` rounds twice, so the two sums would silently diverge.
/// That is why every output instruction is checked, not just the first.
#[test]
fn test_single_use_detection_prevents_incorrect_fusion() {
    let r0 = make_vreg(0, PtxType::F32);
    let r1 = make_vreg(1, PtxType::F32);
    let r2 = make_vreg(2, PtxType::F32);
    let r3 = make_vreg(3, PtxType::F32);
    let r4 = make_vreg(4, PtxType::F32);
    let r5 = make_vreg(5, PtxType::F32);
    let instructions = vec![
        make_mul(r2, r0, r1),
        make_add(r4, r2, r3),
        make_add(r5, r2, r3),
    ];
    let result = pass(instructions);
    assert_eq!(result.len(), 3, "Should not fuse when mul result has multiple uses");
    // Positive check: "not Fma" alone would also accept a mul rewritten into some other op.
    assert!(matches!(result[0].op, PtxOp::Mul), "First instruction should remain mul");
    assert!(
        !result.iter().any(|instr| matches!(instr.op, PtxOp::Fma)),
        "No consumer of a multi-use mul should be fused into an FMA"
    );
}
/// Running the pass on its own output must change nothing.
///
/// The first pass is required to actually fuse. Otherwise a broken pass that never
/// fuses would return the same two instructions both times and trivially look
/// idempotent.
#[test]
fn test_fma_fusion_is_idempotent() {
    let r0 = make_vreg(0, PtxType::F32);
    let r1 = make_vreg(1, PtxType::F32);
    let r2 = make_vreg(2, PtxType::F32);
    let r3 = make_vreg(3, PtxType::F32);
    let r4 = make_vreg(4, PtxType::F32);
    let instructions = vec![make_mul(r2, r0, r1), make_add(r4, r2, r3)];
    let first_pass = pass(instructions);
    assert_eq!(first_pass.len(), 1, "First pass should fuse mul+add into a single FMA");
    let second_pass = pass(first_pass.clone());
    assert_eq!(first_pass.len(), second_pass.len(), "FMA fusion should be idempotent");
    // Compare opcode variants pairwise; `discriminant` needs no PartialEq on PtxOp.
    let same_ops = first_pass
        .iter()
        .zip(&second_pass)
        .all(|(a, b)| std::mem::discriminant(&a.op) == std::mem::discriminant(&b.op));
    assert!(same_ops, "Second pass must not change any opcode");
}
/// Smoke-tests that the pass handles a large, fusion-heavy input within a loose time budget.
///
/// The input is many independent `mul`/`add` pairs, each eligible for fusion. The timed
/// run therefore exercises the fusion path (use detection and rewriting), not just a
/// pass-through of non-fusible instructions. A wall-clock bound cannot prove O(n); it
/// only catches gross regressions such as an accidentally quadratic scan.
#[test]
fn test_fma_pass_linear_complexity() {
    const PAIRS: u32 = 500;
    let mut instructions = Vec::with_capacity(2 * PAIRS as usize);
    for pair in 0..PAIRS {
        // Five fresh registers per pair keep pairs independent, so every mul result
        // has exactly one use and is adjacent to the add that consumes it.
        let base = pair * 5;
        let reg = |offset: u32| make_vreg(base + offset, PtxType::F32);
        instructions.push(make_mul(reg(2), reg(0), reg(1)));
        instructions.push(make_add(reg(4), reg(2), reg(3)));
    }
    let start = std::time::Instant::now();
    let result = pass(instructions);
    let elapsed = start.elapsed();
    assert_eq!(
        result.len(),
        PAIRS as usize,
        "Every independent mul/add pair should fuse into one FMA"
    );
    assert!(
        result.iter().all(|instr| matches!(instr.op, PtxOp::Fma)),
        "All output instructions should be FMAs"
    );
    assert!(elapsed.as_millis() < 100, "FMA pass should have O(n) complexity, took {:?}", elapsed);
}
/// Instructions with no `mul` -> `add` pattern pass through with their opcodes intact.
#[test]
fn test_fma_preserves_non_fusible() {
    let r0 = make_vreg(0, PtxType::F32);
    let r1 = make_vreg(1, PtxType::F32);
    let r2 = make_vreg(2, PtxType::F32);
    let instructions = vec![
        PtxInstruction::new(PtxOp::Mov, PtxType::F32)
            .dst(Operand::Reg(r0))
            .src(Operand::ImmF32(1.0)),
        make_add(r2, r0, r1),
    ];
    let result = pass(instructions);
    assert_eq!(result.len(), 2, "Non-fusible instructions should be preserved");
    // Length alone would accept rewritten instructions; pin the opcodes and their order too.
    assert!(matches!(result[0].op, PtxOp::Mov), "mov should be left unchanged");
    assert!(matches!(result[1].op, PtxOp::Add), "add without a mul producer should stay an add");
}
/// An empty instruction stream yields an empty result.
#[test]
fn test_empty_input() {
    assert!(pass(Vec::new()).is_empty());
}
/// An integer (`U32`) mul feeding an add must be left alone; only float ops are fused.
#[test]
fn test_integer_ops_not_fused() {
    let r0 = make_vreg(0, PtxType::U32);
    let r1 = make_vreg(1, PtxType::U32);
    let r2 = make_vreg(2, PtxType::U32);
    let r3 = make_vreg(3, PtxType::U32);
    let r4 = make_vreg(4, PtxType::U32);
    let instructions = vec![
        PtxInstruction::new(PtxOp::Mul, PtxType::U32)
            .dst(Operand::Reg(r2))
            .src(Operand::Reg(r0))
            .src(Operand::Reg(r1)),
        PtxInstruction::new(PtxOp::Add, PtxType::U32)
            .dst(Operand::Reg(r4))
            .src(Operand::Reg(r2))
            .src(Operand::Reg(r3)),
    ];
    let result = pass(instructions);
    assert_eq!(result.len(), 2, "Integer ops should not be fused (no integer FMA)");
    // Pin the opcodes as well, so a same-length but rewritten output is caught.
    assert!(matches!(result[0].op, PtxOp::Mul), "integer mul should be preserved");
    assert!(matches!(result[1].op, PtxOp::Add), "integer add should be preserved");
    assert!(
        !result.iter().any(|instr| matches!(instr.op, PtxOp::Fma)),
        "No FMA should be produced for integer types"
    );
}
mod rounding_and_advanced;
mod type_and_edge_cases;