use std::collections::HashMap;
use super::super::instructions::{Operand, PtxInstruction, PtxOp, RoundingMode};
use super::super::registers::VirtualReg;
use super::super::types::PtxType;
/// Fuses `mul` + `add` pairs into single `fma` instructions.
///
/// Every instruction is offered to `try_fuse_mul_add`. When it reports a
/// fusion, the `add` is replaced in place by the returned `fma`, and the
/// absorbed `mul` is dropped from the output. All other instructions are
/// passed through unchanged and in their original relative order.
///
/// An empty input is returned as-is.
#[must_use]
pub fn pass(instructions: Vec<PtxInstruction>) -> Vec<PtxInstruction> {
    if instructions.is_empty() {
        return instructions;
    }
    let use_counts = count_register_uses(&instructions);
    let definitions = build_def_map(&instructions);

    // Decide every fusion against the unmodified stream first, so indices stay
    // stable while scanning.
    let mut fused_muls: std::collections::HashSet<usize> = std::collections::HashSet::new();
    let mut fma_replacements: HashMap<usize, PtxInstruction> = HashMap::new();
    for add_idx in 0..instructions.len() {
        if let Some((fma, mul_idx)) =
            try_fuse_mul_add(add_idx, &instructions, &use_counts, &definitions)
        {
            fused_muls.insert(mul_idx);
            fma_replacements.insert(add_idx, fma);
        }
    }

    // `instructions` is owned, so move each instruction (or its replacement)
    // into the result instead of cloning it.
    let mut result = Vec::with_capacity(instructions.len() - fused_muls.len());
    for (i, instr) in instructions.into_iter().enumerate() {
        if fused_muls.contains(&i) {
            continue;
        }
        result.push(fma_replacements.remove(&i).unwrap_or(instr));
    }
    result
}
/// Counts how many times each virtual register is read across `instructions`.
///
/// A read is either an `Operand::Reg` in an instruction's `srcs` or the
/// register of the instruction's guard predicate. Destination registers are
/// not counted. A register read twice by the same instruction counts twice.
///
/// NOTE(review): registers nested inside other operand kinds (if `Operand`
/// has any, e.g. memory addresses) are not counted here — confirm, or a
/// product used elsewhere could be mistaken for single-use.
fn count_register_uses(instructions: &[PtxInstruction]) -> HashMap<VirtualReg, usize> {
    let mut counts: HashMap<VirtualReg, usize> = HashMap::new();
    for instr in instructions {
        let src_regs = instr.srcs.iter().filter_map(|src| match src {
            Operand::Reg(reg) => Some(*reg),
            _ => None,
        });
        // The guard predicate register is read as well.
        let pred_reg = instr.predicate.as_ref().map(|p| p.reg);
        for reg in src_regs.chain(pred_reg) {
            *counts.entry(reg).or_default() += 1;
        }
    }
    counts
}
/// Maps each register written as an instruction's `dst` to that instruction's
/// index. If a register is written more than once, the highest index wins.
fn build_def_map(instructions: &[PtxInstruction]) -> HashMap<VirtualReg, usize> {
    instructions
        .iter()
        .enumerate()
        .filter_map(|(idx, instr)| match &instr.dst {
            Some(Operand::Reg(reg)) => Some((*reg, idx)),
            _ => None,
        })
        // Collecting inserts in order, so later writers overwrite earlier ones.
        .collect()
}
/// Attempts to replace the instruction at `add_idx` with an `fma`.
///
/// Only `add` instructions of type `F32` or `F64` qualify. Each register
/// source of the add is tried in order, and the first one that fuses wins.
/// Returns the replacement `fma` together with the index of the `mul` it
/// absorbs.
///
/// # Panics
///
/// Panics if `add_idx` is out of bounds for `instructions`.
fn try_fuse_mul_add(
    add_idx: usize,
    instructions: &[PtxInstruction],
    use_counts: &HashMap<VirtualReg, usize>,
    definitions: &HashMap<VirtualReg, usize>,
) -> Option<(PtxInstruction, usize)> {
    let add_instr = &instructions[add_idx];
    if !matches!(add_instr.op, PtxOp::Add) || !matches!(add_instr.ty, PtxType::F32 | PtxType::F64)
    {
        return None;
    }
    add_instr.srcs.iter().enumerate().find_map(|(src_idx, src)| {
        let Operand::Reg(mul_result) = src else {
            return None;
        };
        let (fma, mul_idx) = try_fuse_source(
            add_instr,
            src_idx,
            *mul_result,
            instructions,
            use_counts,
            definitions,
        )?;
        // The fma reads the mul's operands at the add's position. That is only
        // equivalent if the mul comes first and nothing in between overwrites
        // its operands. The def map keeps the *last* writer, so it may point at
        // a mul that follows the add.
        if mul_idx < add_idx && mul_operands_unchanged(instructions, mul_idx, add_idx) {
            Some((fma, mul_idx))
        } else {
            None
        }
    })
}

/// Returns `true` if no instruction strictly between `mul_idx` and `add_idx`
/// writes (as its `dst` register) a register read by `instructions[mul_idx]`.
///
/// Requires `mul_idx < add_idx`.
///
/// NOTE(review): this only looks at straight-line instruction order and does
/// not model branches or loop back-edges. If `VirtualReg`s are
/// single-assignment it always holds — confirm.
fn mul_operands_unchanged(instructions: &[PtxInstruction], mul_idx: usize, add_idx: usize) -> bool {
    let mul_reads: Vec<VirtualReg> = instructions[mul_idx]
        .srcs
        .iter()
        .filter_map(|src| match src {
            Operand::Reg(reg) => Some(*reg),
            _ => None,
        })
        .collect();
    instructions[mul_idx + 1..add_idx].iter().all(|instr| match &instr.dst {
        Some(Operand::Reg(reg)) => !mul_reads.contains(reg),
        _ => true,
    })
}
/// Tries to fuse the `mul` that defines `mul_result` into `add_instr`.
///
/// `mul_result` is expected to be the register at `add_instr.srcs[src_idx]`.
/// Fusion requires all of the following:
/// - `mul_result` is read exactly once in the whole stream.
/// - Its recorded defining instruction passes `is_fusable_mul`.
/// - That `mul` is not predicated.
///
/// On success, returns `fma dst, a, b, c` plus the index of the `mul` to
/// delete. Here `a` and `b` are the mul's first two operands and `c` is the
/// add's other operand. The `fma` carries the add's guard predicate and the
/// mul's rounding mode, defaulting to `Rn`.
fn try_fuse_source(
    add_instr: &PtxInstruction,
    src_idx: usize,
    mul_result: VirtualReg,
    instructions: &[PtxInstruction],
    use_counts: &HashMap<VirtualReg, usize>,
    definitions: &HashMap<VirtualReg, usize>,
) -> Option<(PtxInstruction, usize)> {
    // A product that is read anywhere else must stay materialized.
    if use_counts.get(&mul_result) != Some(&1) {
        return None;
    }
    let &def_idx = definitions.get(&mul_result)?;
    let mul_instr = &instructions[def_idx];
    // A predicated mul may leave `mul_result` untouched when its guard is
    // false. An fma cannot reproduce that, so do not fuse it.
    if mul_instr.predicate.is_some() || !is_fusable_mul(mul_instr, add_instr) {
        return None;
    }
    let other_src = if src_idx == 0 { add_instr.srcs.get(1)? } else { add_instr.srcs.first()? };
    let a = mul_instr.srcs.first()?;
    let b = mul_instr.srcs.get(1)?;
    let mut fma = PtxInstruction::new(PtxOp::Fma, add_instr.ty.clone())
        .dst(add_instr.dst.clone()?)
        .src(a.clone())
        .src(b.clone())
        .src(other_src.clone())
        .rounding(mul_instr.rounding.unwrap_or(RoundingMode::Rn));
    // The fma replaces the add, so it must run under the add's guard.
    // Otherwise a predicated add would turn into an unconditional write.
    fma.predicate = add_instr.predicate.clone();
    Some((fma, def_idx))
}
/// Returns `true` when `mul_instr` can be absorbed into `add_instr`. That
/// requires a `mul` of the same type as the add, with at least two source
/// operands, whose rounding modifier is compatible with the add's.
fn is_fusable_mul(mul_instr: &PtxInstruction, add_instr: &PtxInstruction) -> bool {
    if !matches!(mul_instr.op, PtxOp::Mul) || mul_instr.ty != add_instr.ty {
        return false;
    }
    // The fma needs both multiplicands.
    if mul_instr.srcs.len() < 2 {
        return false;
    }
    rounding_modes_compatible(mul_instr.rounding.as_ref(), add_instr.rounding.as_ref())
}
/// Decides whether two optional rounding modifiers may be merged into one fma.
///
/// A missing modifier is treated as equivalent to `Rn` only. Two explicit
/// modifiers must be equal, and two missing modifiers are compatible.
fn rounding_modes_compatible(a: Option<&RoundingMode>, b: Option<&RoundingMode>) -> bool {
    match (a, b) {
        (None, None) => true,
        (Some(lhs), Some(rhs)) => lhs == rhs,
        // Exactly one side is explicit: it must be the default mode.
        (Some(mode), None) | (None, Some(mode)) => matches!(mode, RoundingMode::Rn),
    }
}
#[cfg(test)]
mod tests;