use std::collections::HashMap;
use super::super::instructions::{Operand, PtxInstruction, PtxOp, RoundingMode};
use super::super::registers::VirtualReg;
use super::super::types::PtxType;
/// Fuses `mul` + `add` pairs into single `fma` instructions.
///
/// Every instruction is offered to `try_fuse_mul_add`. For each accepted
/// pair, the `add` is replaced in place by the new `fma` and the absorbed
/// `mul` is dropped. All other instructions keep their original order.
///
/// The input vector is consumed; surviving instructions are moved (not
/// cloned) into the returned vector.
#[must_use]
pub fn pass(instructions: Vec<PtxInstruction>) -> Vec<PtxInstruction> {
    if instructions.is_empty() {
        return instructions;
    }
    let use_counts = count_register_uses(&instructions);
    let definitions = build_def_map(&instructions);

    // Phase 1: decide every fusion against the unmodified instruction list.
    let mut fused_muls: std::collections::HashSet<usize> = std::collections::HashSet::new();
    let mut fma_replacements: HashMap<usize, PtxInstruction> = HashMap::new();
    for add_idx in 0..instructions.len() {
        if let Some((fma, mul_idx)) =
            try_fuse_mul_add(add_idx, &instructions, &use_counts, &definitions)
        {
            fused_muls.insert(mul_idx);
            fma_replacements.insert(add_idx, fma);
        }
    }

    // Phase 2: rebuild the stream. The original instructions are moved out of
    // the input rather than cloned; fused adds are swapped for their fma.
    let mut result = Vec::with_capacity(instructions.len() - fused_muls.len());
    for (idx, instr) in instructions.into_iter().enumerate() {
        if fused_muls.contains(&idx) {
            continue;
        }
        result.push(fma_replacements.remove(&idx).unwrap_or(instr));
    }
    result
}
/// Counts how many times each virtual register is read across `instructions`.
///
/// A read is either a source operand of the form `Operand::Reg` or the
/// register of an instruction's guard predicate. Destinations are not counted.
///
/// NOTE(review): source operands other than `Operand::Reg` are ignored. If
/// `Operand` has variants that embed a register (e.g. an address base), those
/// reads are not counted here. Confirm against the `Operand` definition: an
/// undercount could let the fusion pass delete a `mul` whose result is still
/// needed.
fn count_register_uses(instructions: &[PtxInstruction]) -> HashMap<VirtualReg, usize> {
    let mut counts: HashMap<VirtualReg, usize> = HashMap::new();
    for instr in instructions {
        let source_regs = instr.srcs.iter().filter_map(|src| match src {
            Operand::Reg(reg) => Some(*reg),
            _ => None,
        });
        // The guard predicate is a read as well; take its register directly
        // instead of wrapping it in a temporary `Operand`.
        let guard_reg = instr.predicate.as_ref().map(|pred| pred.reg);
        for reg in source_regs.chain(guard_reg) {
            *counts.entry(reg).or_default() += 1;
        }
    }
    counts
}
/// Maps each register to the index of the instruction that writes it.
///
/// Only `Some(Operand::Reg(..))` destinations count as definitions. Pairs are
/// collected in program order, so a register written more than once maps to
/// its *last* writer (later entries overwrite earlier ones).
fn build_def_map(instructions: &[PtxInstruction]) -> HashMap<VirtualReg, usize> {
    instructions
        .iter()
        .enumerate()
        .filter_map(|(position, instr)| match &instr.dst {
            Some(Operand::Reg(written)) => Some((*written, position)),
            _ => None,
        })
        .collect()
}
fn try_fuse_mul_add(
add_idx: usize,
instructions: &[PtxInstruction],
use_counts: &HashMap<VirtualReg, usize>,
definitions: &HashMap<VirtualReg, usize>,
) -> Option<(PtxInstruction, usize)> {
let add_instr = &instructions[add_idx];
if !matches!(add_instr.op, PtxOp::Add) {
return None;
}
if !matches!(add_instr.ty, PtxType::F32 | PtxType::F64) {
return None;
}
for (src_idx, src) in add_instr.srcs.iter().enumerate() {
if let Operand::Reg(mul_result) = src {
if use_counts.get(mul_result) != Some(&1) {
continue;
}
if let Some(&def_idx) = definitions.get(mul_result) {
let mul_instr = &instructions[def_idx];
if !matches!(mul_instr.op, PtxOp::Mul) {
continue;
}
if mul_instr.ty != add_instr.ty {
continue;
}
if !rounding_modes_compatible(
mul_instr.rounding.as_ref(),
add_instr.rounding.as_ref(),
) {
continue;
}
let other_src = if src_idx == 0 {
add_instr.srcs.get(1)?
} else {
add_instr.srcs.first()?
};
if mul_instr.srcs.len() < 2 {
continue;
}
let a = mul_instr.srcs.first()?;
let b = mul_instr.srcs.get(1)?;
let fma = PtxInstruction::new(PtxOp::Fma, add_instr.ty.clone())
.dst(add_instr.dst.clone()?)
.src(a.clone())
.src(b.clone())
.src(other_src.clone())
.rounding(mul_instr.rounding.unwrap_or(RoundingMode::Rn));
return Some((fma, def_idx));
}
}
}
None
}
/// Reports whether a `mul` and an `add` with these rounding modes may be fused.
///
/// A missing mode is treated exactly like an explicit `Rn`. Both sides are
/// normalised that way and then compared, so `None`/`Rn` mix freely while any
/// other mode must match on both sides.
fn rounding_modes_compatible(a: Option<&RoundingMode>, b: Option<&RoundingMode>) -> bool {
    let mul_mode = a.unwrap_or(&RoundingMode::Rn);
    let add_mode = b.unwrap_or(&RoundingMode::Rn);
    mul_mode == add_mode
}
// Unit tests for the mul+add -> fma fusion pass and its helpers. Each test
// builds a small instruction sequence, runs `pass` (or a helper directly),
// and checks instruction counts, opcodes, or operand layout.
#[cfg(test)]
mod tests {
use super::*;
// Builds an F32 `mul dst, a, b` with register operands and no rounding mode.
fn make_mul(dst: VirtualReg, a: VirtualReg, b: VirtualReg) -> PtxInstruction {
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(dst))
.src(Operand::Reg(a))
.src(Operand::Reg(b))
}
// Builds an F32 `add dst, a, b` with register operands and no rounding mode.
fn make_add(dst: VirtualReg, a: VirtualReg, b: VirtualReg) -> PtxInstruction {
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(dst))
.src(Operand::Reg(a))
.src(Operand::Reg(b))
}
// Shorthand for `VirtualReg::new`.
fn make_vreg(id: u32, ty: PtxType) -> VirtualReg {
VirtualReg::new(id, ty)
}
// A single-use mul feeding an add collapses into one fma.
#[test]
fn test_fma_reduces_instruction_count() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![make_mul(r2, r0, r1), make_add(r4, r2, r3)];
let result = pass(instructions);
assert_eq!(
result.len(),
1,
"FMA fusion should reduce 2 instructions to 1"
);
assert!(
matches!(result[0].op, PtxOp::Fma),
"Result should be FMA instruction"
);
}
// r2 is read by two adds, so the mul must be kept and nothing fuses.
#[test]
fn test_single_use_detection_prevents_incorrect_fusion() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let r5 = make_vreg(5, PtxType::F32);
let instructions = vec![
make_mul(r2, r0, r1),
make_add(r4, r2, r3),
make_add(r5, r2, r3),
];
let result = pass(instructions);
assert_eq!(
result.len(),
3,
"Should not fuse when mul result has multiple uses"
);
assert!(
!matches!(result[0].op, PtxOp::Fma),
"First instruction should remain mul"
);
}
// Running the pass on its own output leaves the instruction count unchanged.
#[test]
fn test_fma_fusion_is_idempotent() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![make_mul(r2, r0, r1), make_add(r4, r2, r3)];
let first_pass = pass(instructions);
let second_pass = pass(first_pass.clone());
assert_eq!(
first_pass.len(),
second_pass.len(),
"FMA fusion should be idempotent"
);
}
// Smoke test: 1000 `mov` instructions must be processed within 100 ms.
// NOTE(review): the input contains no `add`/`mul`, so no fusion candidate is
// examined; the wall-clock bound also depends on build profile and machine
// load and may be flaky on slow CI.
#[test]
fn test_fma_pass_linear_complexity() {
let mut instructions = Vec::with_capacity(1000);
for i in 0..1000 {
let r = make_vreg(i, PtxType::F32);
instructions.push(PtxInstruction::new(PtxOp::Mov, PtxType::F32).dst(Operand::Reg(r)));
}
let start = std::time::Instant::now();
let _result = pass(instructions);
let elapsed = start.elapsed();
assert!(
elapsed.as_millis() < 100,
"FMA pass should have O(n) complexity, took {:?}",
elapsed
);
}
// The add's register source r0 is defined by a `mov`, not a `mul`, so both
// instructions survive unchanged.
#[test]
fn test_fma_preserves_non_fusible() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mov, PtxType::F32)
.dst(Operand::Reg(r0))
.src(Operand::ImmF32(1.0)),
make_add(r2, r0, r1), ];
let result = pass(instructions);
assert_eq!(
result.len(),
2,
"Non-fusible instructions should be preserved"
);
}
// An empty input is returned as-is.
#[test]
fn test_empty_input() {
let result = pass(vec![]);
assert!(result.is_empty());
}
// Only F32/F64 adds are candidates; a U32 mul+add pair is left alone.
#[test]
fn test_integer_ops_not_fused() {
let r0 = make_vreg(0, PtxType::U32);
let r1 = make_vreg(1, PtxType::U32);
let r2 = make_vreg(2, PtxType::U32);
let r3 = make_vreg(3, PtxType::U32);
let r4 = make_vreg(4, PtxType::U32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::U32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1)),
PtxInstruction::new(PtxOp::Add, PtxType::U32)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r2))
.src(Operand::Reg(r3)),
];
let result = pass(instructions);
assert_eq!(
result.len(),
2,
"Integer ops should not be fused (no integer FMA)"
);
}
// F64 pairs fuse too, and the resulting fma keeps the F64 type.
#[test]
fn test_f64_fusion() {
let r0 = make_vreg(0, PtxType::F64);
let r1 = make_vreg(1, PtxType::F64);
let r2 = make_vreg(2, PtxType::F64);
let r3 = make_vreg(3, PtxType::F64);
let r4 = make_vreg(4, PtxType::F64);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F64)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1)),
PtxInstruction::new(PtxOp::Add, PtxType::F64)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r2))
.src(Operand::Reg(r3)),
];
let result = pass(instructions);
assert_eq!(result.len(), 1, "F64 mul+add should fuse to FMA");
assert!(matches!(result[0].op, PtxOp::Fma));
assert_eq!(result[0].ty, PtxType::F64);
}
// The mul result may appear as either add operand (here the second one).
#[test]
fn test_fusion_mul_result_as_second_operand() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![
make_mul(r2, r0, r1),
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r3)) .src(Operand::Reg(r2)), ];
let result = pass(instructions);
assert_eq!(
result.len(),
1,
"Should fuse even when mul result is second add operand"
);
assert!(matches!(result[0].op, PtxOp::Fma));
}
// An F32 mul feeding an F64 add must not fuse.
#[test]
fn test_type_mismatch_prevents_fusion() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F64);
let r4 = make_vreg(4, PtxType::F64);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1)),
PtxInstruction::new(PtxOp::Add, PtxType::F64)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r2))
.src(Operand::Reg(r3)),
];
let result = pass(instructions);
assert_eq!(
result.len(),
2,
"Type mismatch between mul and add should prevent fusion"
);
}
// `mul` with explicit Rn + `add` with no mode fuses: a missing mode is
// treated like Rn.
#[test]
fn test_rounding_mode_rn_matches_none() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1))
.rounding(RoundingMode::Rn),
make_add(r4, r2, r3), ];
let result = pass(instructions);
assert_eq!(
result.len(),
1,
"Explicit Rn should be compatible with None (default)"
);
}
// Mirror of the previous test: mul without a mode, add with explicit Rn.
#[test]
fn test_rounding_mode_none_matches_rn() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![
make_mul(r2, r0, r1), PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r2))
.src(Operand::Reg(r3))
.rounding(RoundingMode::Rn),
];
let result = pass(instructions);
assert_eq!(
result.len(),
1,
"None should be compatible with explicit Rn"
);
}
// Matching explicit non-default modes (Rz/Rz) are compatible.
#[test]
fn test_rounding_mode_same_explicit_mode() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1))
.rounding(RoundingMode::Rz),
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r2))
.src(Operand::Reg(r3))
.rounding(RoundingMode::Rz),
];
let result = pass(instructions);
assert_eq!(
result.len(),
1,
"Same explicit rounding modes should be compatible"
);
}
// Differing explicit modes (Rn vs Rz) block fusion.
#[test]
fn test_incompatible_rounding_modes_prevent_fusion() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1))
.rounding(RoundingMode::Rn),
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r2))
.src(Operand::Reg(r3))
.rounding(RoundingMode::Rz),
];
let result = pass(instructions);
assert_eq!(
result.len(),
2,
"Incompatible rounding modes should prevent fusion"
);
}
// A non-Rn mode (Rz) is not compatible with a missing mode.
#[test]
fn test_non_rn_with_none_incompatible() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1))
.rounding(RoundingMode::Rz),
make_add(r4, r2, r3),
];
let result = pass(instructions);
assert_eq!(
result.len(),
2,
"Non-Rn mode with implicit None should be incompatible"
);
}
// The mul's guard register `pred` is counted as a use of `pred`, not of the
// mul result r2. Here r2 has no other writer, and the pair still fuses.
#[test]
fn test_predicate_use_counts() {
use super::super::super::instructions::Predicate;
let pred = make_vreg(0, PtxType::Pred);
let r0 = make_vreg(1, PtxType::F32);
let r1 = make_vreg(2, PtxType::F32);
let r2 = make_vreg(3, PtxType::F32);
let r3 = make_vreg(4, PtxType::F32);
let r4 = make_vreg(5, PtxType::F32);
let mut mul_instr = PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1));
mul_instr.predicate = Some(Predicate {
reg: pred,
negated: false,
});
let instructions = vec![mul_instr, make_add(r4, r2, r3)];
let result = pass(instructions);
assert_eq!(result.len(), 1, "Predicate usage should not prevent fusion");
}
// An add whose register source comes from a `mov` is not a fusion candidate.
#[test]
fn test_non_mul_definition_not_fused() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mov, PtxType::F32)
.dst(Operand::Reg(r0))
.src(Operand::ImmF32(1.0)),
make_add(r2, r0, r1),
];
let result = pass(instructions);
assert_eq!(
result.len(),
2,
"Add with non-mul source definition should not fuse"
);
}
// A mul with a single source cannot supply both fma factors.
#[test]
fn test_mul_insufficient_sources() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r1))
.src(Operand::Reg(r0)), make_add(r3, r1, r2),
];
let result = pass(instructions);
assert_eq!(
result.len(),
2,
"Mul with insufficient sources should not fuse"
);
}
// An add mixing a `mov`-defined register and an immediate does not fuse.
#[test]
fn test_add_with_immediate_source() {
let r0 = make_vreg(0, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mov, PtxType::F32)
.dst(Operand::Reg(r0))
.src(Operand::ImmF32(1.0)),
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::ImmF32(2.0)), ];
let result = pass(instructions);
assert_eq!(result.len(), 2, "Add with immediate source should not fuse");
}
// Sources with no definition in the list are tolerated and left untouched.
#[test]
fn test_undefined_register_source() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let instructions = vec![make_add(r2, r0, r1)];
let result = pass(instructions);
assert_eq!(
result.len(),
1,
"Add with undefined register should be preserved"
);
}
// Direct checks of the rounding compatibility table.
#[test]
fn test_rounding_modes_compatible_direct() {
assert!(rounding_modes_compatible(None, None));
assert!(rounding_modes_compatible(None, Some(&RoundingMode::Rn)));
assert!(rounding_modes_compatible(Some(&RoundingMode::Rn), None));
assert!(rounding_modes_compatible(
Some(&RoundingMode::Rn),
Some(&RoundingMode::Rn)
));
assert!(rounding_modes_compatible(
Some(&RoundingMode::Rz),
Some(&RoundingMode::Rz)
));
assert!(rounding_modes_compatible(
Some(&RoundingMode::Rp),
Some(&RoundingMode::Rp)
));
assert!(rounding_modes_compatible(
Some(&RoundingMode::Rm),
Some(&RoundingMode::Rm)
));
assert!(!rounding_modes_compatible(
Some(&RoundingMode::Rn),
Some(&RoundingMode::Rz)
));
assert!(!rounding_modes_compatible(
Some(&RoundingMode::Rp),
Some(&RoundingMode::Rm)
));
assert!(!rounding_modes_compatible(Some(&RoundingMode::Rz), None));
assert!(!rounding_modes_compatible(None, Some(&RoundingMode::Rz)));
}
// Checks the shape of the fused instruction: opcode, type, three sources,
// destination taken from the add, and a rounding mode present.
#[test]
fn test_fma_instruction_structure() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![make_mul(r2, r0, r1), make_add(r4, r2, r3)];
let result = pass(instructions);
assert_eq!(result.len(), 1);
let fma = &result[0];
assert!(matches!(fma.op, PtxOp::Fma));
assert_eq!(fma.ty, PtxType::F32);
assert_eq!(fma.srcs.len(), 3);
match &fma.dst {
Some(Operand::Reg(r)) => assert_eq!(*r, r4),
_ => panic!("Expected register destination"),
}
assert!(fma.rounding.is_some());
}
// An add without a destination cannot become an fma.
#[test]
fn test_add_no_destination() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let instructions = vec![
make_mul(r2, r0, r1),
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.src(Operand::Reg(r2))
.src(Operand::Reg(r0)),
];
let result = pass(instructions);
assert_eq!(result.len(), 2);
}
// Immediates are not counted; r0 read twice by the same add counts twice.
#[test]
fn test_count_uses_with_immediate_sources() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mov, PtxType::F32)
.dst(Operand::Reg(r0))
.src(Operand::ImmF32(1.0)), make_add(r1, r0, r0), ];
let counts = count_register_uses(&instructions);
assert_eq!(counts.get(&r0), Some(&2));
}
// Instructions without a register destination (here `bra`) define nothing.
#[test]
fn test_build_def_map_no_reg_dst() {
let r0 = make_vreg(0, PtxType::F32);
let instr1 = PtxInstruction::new(PtxOp::Mov, PtxType::F32)
.dst(Operand::Reg(r0))
.src(Operand::ImmF32(1.0));
let instr2 = PtxInstruction::new(PtxOp::Bra, PtxType::Pred);
let instructions = vec![instr1, instr2];
let defs = build_def_map(&instructions);
assert_eq!(defs.get(&r0), Some(&0));
assert_eq!(defs.len(), 1);
}
// An add with a single source has no addend for the fma.
#[test]
fn test_add_single_source() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let instructions = vec![
make_mul(r1, r0, r0),
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r1)),
];
let result = pass(instructions);
assert_eq!(result.len(), 2);
}
// Two independent mul/add pairs both fuse. r0 and r3 are each read by a mul
// and an add; that is allowed because only the mul results must be single-use.
#[test]
fn test_multiple_fusions() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let r5 = make_vreg(5, PtxType::F32);
let r6 = make_vreg(6, PtxType::F32);
let r7 = make_vreg(7, PtxType::F32);
let instructions = vec![
make_mul(r2, r0, r1), make_mul(r5, r3, r4), make_add(r6, r2, r0), make_add(r7, r5, r3), ];
let result = pass(instructions);
assert_eq!(result.len(), 2);
assert!(matches!(result[0].op, PtxOp::Fma));
assert!(matches!(result[1].op, PtxOp::Fma));
}
// The fma takes the mul's explicit rounding mode (Rz here).
#[test]
fn test_fma_inherits_mul_rounding() {
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::F32);
let r2 = make_vreg(2, PtxType::F32);
let r3 = make_vreg(3, PtxType::F32);
let r4 = make_vreg(4, PtxType::F32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mul, PtxType::F32)
.dst(Operand::Reg(r2))
.src(Operand::Reg(r0))
.src(Operand::Reg(r1))
.rounding(RoundingMode::Rz),
PtxInstruction::new(PtxOp::Add, PtxType::F32)
.dst(Operand::Reg(r4))
.src(Operand::Reg(r2))
.src(Operand::Reg(r3))
.rounding(RoundingMode::Rz),
];
let result = pass(instructions);
assert_eq!(result.len(), 1);
assert_eq!(result[0].rounding, Some(RoundingMode::Rz));
}
// Special-register and immediate sources contribute no virtual-register uses.
#[test]
fn test_special_register_in_sources() {
use super::super::super::registers::PtxReg;
let r0 = make_vreg(0, PtxType::F32);
let r1 = make_vreg(1, PtxType::U32);
let instructions = vec![
PtxInstruction::new(PtxOp::Mov, PtxType::U32)
.dst(Operand::Reg(r1))
.src(Operand::SpecialReg(PtxReg::TidX)),
PtxInstruction::new(PtxOp::Mov, PtxType::F32)
.dst(Operand::Reg(r0))
.src(Operand::ImmF32(1.0)),
];
let counts = count_register_uses(&instructions);
assert_eq!(counts.len(), 0);
}
}