use super::id;
use rspirv::dr::{Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{Op, Word};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_middle::bug;
pub fn collect_types(module: &Module) -> FxHashMap<Word, Instruction> {
module
.types_global_values
.iter()
.filter_map(|inst| Some((inst.result_id?, inst.clone())))
.collect()
}
/// Returns the number of direct constituents of the composite type `ty_id`, when it is known.
///
/// Supported cases:
/// - `OpTypeStruct`: the number of members.
/// - `OpTypeVector`: the component count.
/// - `OpTypeArray`: only when the length operand is a plain `OpConstant` of an `OpTypeInt`.
///
/// Anything else (unknown ids, other types, spec-constant lengths) yields `None`.
fn composite_count(types: &FxHashMap<Word, Instruction>, ty_id: Word) -> Option<usize> {
    let ty = types.get(&ty_id)?;
    let opcode = ty.class.opcode;
    if opcode == Op::TypeStruct {
        return Some(ty.operands.len());
    }
    if opcode == Op::TypeVector {
        return Some(ty.operands[1].unwrap_literal_int32() as usize);
    }
    if opcode != Op::TypeArray {
        return None;
    }

    // Array length is an id; it must resolve to a global integer `OpConstant`.
    let length = types.get(&ty.operands[1].unwrap_id_ref())?;
    if length.class.opcode != Op::Constant {
        return None;
    }
    let length_ty = types.get(&length.result_type.unwrap())?;
    if length_ty.class.opcode != Op::TypeInt {
        return None;
    }
    let count = match length.operands[0] {
        Operand::LiteralInt32(n) => n as usize,
        Operand::LiteralInt64(n) => n as usize,
        _ => bug!(),
    };
    Some(count)
}
pub fn composite_construct(types: &FxHashMap<Word, Instruction>, function: &mut Function) {
let defs = function
.all_inst_iter()
.filter_map(|inst| Some((inst.result_id?, inst.clone())))
.collect::<FxHashMap<Word, Instruction>>();
for block in &mut function.blocks {
for inst in &mut block.instructions {
if inst.class.opcode != Op::CompositeInsert {
continue;
}
let component_count = match composite_count(types, inst.result_type.unwrap()) {
Some(c) => c,
None => continue,
};
let mut components = vec![None; component_count];
let mut cur_inst: &Instruction = inst;
while cur_inst.class.opcode == Op::CompositeInsert {
if cur_inst.operands.len() != 3 {
break;
}
let value = cur_inst.operands[0].unwrap_id_ref();
let index = cur_inst.operands[2].unwrap_literal_int32() as usize;
if index >= components.len() {
break;
}
if components[index].is_none() {
components[index] = Some(value);
}
cur_inst = match defs.get(&cur_inst.operands[1].unwrap_id_ref()) {
Some(i) => i,
None => break,
};
}
if let Some(composite_construct_operands) = components
.into_iter()
.map(|v| v.map(Operand::IdRef))
.collect::<Option<Vec<_>>>()
{
*inst = Instruction::new(
Op::CompositeConstruct,
inst.result_type,
inst.result_id,
composite_construct_operands,
);
}
}
}
}
/// How one operand of a vectorized instruction is derived from the matching operand of each
/// per-lane scalar instruction.
#[derive(Debug)]
enum IdentifiedOperand {
    /// An operand copied verbatim because it is identical in every lane.
    NonValue(Operand),
    /// Lane `i` reads component `i` of this existing vector, which can be used directly.
    Vector(Word),
    /// One id per lane, in lane order, not drawn from a single existing vector. These are packed
    /// into a fresh vector before use.
    Scalars(Vec<Word>),
}
/// If `id` is defined by a single-index `OpCompositeExtract` reading from a vector that has
/// exactly `vector_width` components, returns that vector's id and the extracted index.
///
/// The source vector may be a function-local value (`defs`) or a global (`types`).
fn get_composite_and_index(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    id: Word,
    vector_width: u32,
) -> Option<(Word, u32)> {
    let extract = defs
        .get(&id)
        .filter(|def| def.class.opcode == Op::CompositeExtract && def.operands.len() == 2)?;
    let (source, lane) = (
        extract.operands[0].unwrap_id_ref(),
        extract.operands[1].unwrap_literal_int32(),
    );
    let source_def = match defs.get(&source) {
        Some(def) => def,
        None => types.get(&source)?,
    };
    let source_ty = types.get(&source_def.result_type.unwrap())?;
    let is_matching_vector = source_ty.class.opcode == Op::TypeVector
        && source_ty.operands[1].unwrap_literal_int32() == vector_width;
    if is_matching_vector {
        Some((source, lane))
    } else {
        None
    }
}
/// Checks whether operand `operand_index` of the per-lane instructions in `results` is exactly
/// "component `i` of one vector" for lane `i`. If so, returns that vector's id.
///
/// Every lane's operand must be an id defined by an `OpCompositeExtract` of the same
/// `vector_width`-wide vector, at the index equal to the lane number.
fn match_vector_operand(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    results: &[&Instruction],
    operand_index: usize,
    vector_width: u32,
) -> Option<Word> {
    // The (vector, component) that one lane's operand was extracted from, if any.
    let source_of = |lane_inst: &Instruction| match lane_inst.operands[operand_index] {
        Operand::IdRef(value) => get_composite_and_index(types, defs, value, vector_width),
        _ => None,
    };

    let (vector, first_component) = source_of(results[0])?;
    if first_component != 0 {
        return None;
    }
    let lanes_line_up = results
        .iter()
        .copied()
        .enumerate()
        .skip(1)
        .all(|(lane, lane_inst)| match source_of(lane_inst) {
            Some((other, component)) => other == vector && component as usize == lane,
            None => false,
        });
    if lanes_line_up {
        Some(vector)
    } else {
        None
    }
}
/// Classifies operand `operand_index` across all lanes.
///
/// - If the operand is an existing vector (see `match_vector_operand`), returns `Vector`.
/// - Otherwise, if every lane's operand is an id, returns those ids as `Scalars`.
/// - Otherwise returns `None`.
fn match_vector_or_scalars_operand(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, Instruction>,
    results: &[&Instruction],
    operand_index: usize,
    vector_width: u32,
) -> Option<IdentifiedOperand> {
    match match_vector_operand(types, defs, results, operand_index, vector_width) {
        Some(vector) => Some(IdentifiedOperand::Vector(vector)),
        None => {
            let mut lanes = Vec::with_capacity(results.len());
            for result in results {
                match result.operands[operand_index] {
                    Operand::IdRef(lane) => lanes.push(lane),
                    _ => return None,
                }
            }
            Some(IdentifiedOperand::Scalars(lanes))
        }
    }
}
/// Returns a copy of operand `operand_index` if it is identical in every instruction of
/// `results`, otherwise `None`.
fn match_all_same_operand(results: &[&Instruction], operand_index: usize) -> Option<Operand> {
    let shared = &results[0].operands[operand_index];
    let uniform = results[1..]
        .iter()
        .all(|other| other.operands[operand_index] == *shared);
    if uniform {
        Some(shared.clone())
    } else {
        None
    }
}
fn match_operands(
types: &FxHashMap<Word, Instruction>,
defs: &FxHashMap<Word, Instruction>,
results: &[&Instruction],
vector_width: u32,
) -> Option<Vec<IdentifiedOperand>> {
let operation_opcode = results[0].class.opcode;
if results.iter().skip(1).any(|r| {
r.class.opcode != operation_opcode || r.operands.len() != results[0].operands.len()
}) {
return None;
}
match operation_opcode {
Op::IAdd
| Op::FAdd
| Op::ISub
| Op::FSub
| Op::IMul
| Op::FMul
| Op::UDiv
| Op::SDiv
| Op::FDiv
| Op::UMod
| Op::SRem
| Op::FRem
| Op::FMod
| Op::ShiftRightLogical
| Op::ShiftRightArithmetic
| Op::ShiftLeftLogical
| Op::BitwiseOr
| Op::BitwiseXor
| Op::BitwiseAnd => {
let left = match_vector_or_scalars_operand(types, defs, results, 0, vector_width)?;
let right = match_vector_or_scalars_operand(types, defs, results, 1, vector_width)?;
match (left, right) {
(IdentifiedOperand::Scalars(_), IdentifiedOperand::Scalars(_)) => None,
(left, right) => Some(vec![left, right]),
}
}
Op::SNegate | Op::FNegate | Op::Not | Op::BitReverse => {
let value = match_vector_operand(types, defs, results, 0, vector_width)?;
Some(vec![IdentifiedOperand::Vector(value)])
}
Op::ExtInst => {
let set = match_all_same_operand(results, 0)?;
let instruction = match_all_same_operand(results, 1)?;
let parameters = (2..results[0].operands.len())
.map(|i| match_vector_or_scalars_operand(types, defs, results, i, vector_width));
let operands = IntoIterator::into_iter([
Some(IdentifiedOperand::NonValue(set)),
Some(IdentifiedOperand::NonValue(instruction)),
])
.chain(parameters)
.collect::<Option<Vec<_>>>()?;
if operands
.iter()
.skip(2)
.all(|p| matches!(p, &IdentifiedOperand::Scalars(_)))
{
return None;
}
Some(operands)
}
_ => None,
}
}
fn process_instruction(
header: &mut ModuleHeader,
types: &FxHashMap<Word, Instruction>,
defs: &FxHashMap<Word, Instruction>,
instructions: &mut Vec<Instruction>,
instruction_index: &mut usize,
) -> Option<Instruction> {
let inst = &instructions[*instruction_index];
if inst.class.opcode != Op::CompositeConstruct {
return None;
}
let inst_result_id = inst.result_id.unwrap();
let vector_ty = inst.result_type.unwrap();
let vector_ty_inst = match types.get(&vector_ty) {
Some(inst) => inst,
_ => return None,
};
if vector_ty_inst.class.opcode != Op::TypeVector {
return None;
}
let vector_width = vector_ty_inst.operands[1].unwrap_literal_int32();
let results = match inst
.operands
.iter()
.map(|op| defs.get(&op.unwrap_id_ref()))
.collect::<Option<Vec<_>>>()
{
Some(r) => r,
None => return None,
};
let operation_opcode = results[0].class.opcode;
let composite_arguments = match_operands(types, defs, &results, vector_width)?;
if operation_opcode == Op::FMul && composite_arguments.len() == 2 {
if let (&IdentifiedOperand::Vector(composite), IdentifiedOperand::Scalars(scalars))
| (IdentifiedOperand::Scalars(scalars), &IdentifiedOperand::Vector(composite)) =
(&composite_arguments[0], &composite_arguments[1])
{
let scalar = scalars[0];
if scalars.iter().skip(1).all(|&s| s == scalar) {
return Some(Instruction::new(
Op::VectorTimesScalar,
inst.result_type,
inst.result_id,
vec![Operand::IdRef(composite), Operand::IdRef(scalar)],
));
}
}
}
let operands = composite_arguments
.into_iter()
.map(|operand| match operand {
IdentifiedOperand::Vector(composite) => Operand::IdRef(composite),
IdentifiedOperand::NonValue(operand) => operand,
IdentifiedOperand::Scalars(scalars) => {
let id = super::id(header);
instructions.insert(
*instruction_index,
Instruction::new(
Op::CompositeConstruct,
Some(vector_ty),
Some(id),
scalars.into_iter().map(Operand::IdRef).collect(),
),
);
*instruction_index += 1;
Operand::IdRef(id)
}
})
.collect();
Some(Instruction::new(
operation_opcode,
Some(vector_ty),
Some(inst_result_id),
operands,
))
}
/// Vectorizes `OpCompositeConstruct`s whose lanes are all computed by the same per-lane
/// operation, block by block (see `process_instruction`).
///
/// Definitions are looked up in a snapshot of `function` taken before any rewriting.
/// Instructions created along the way are therefore never themselves used as lane sources.
pub fn vector_ops(
    header: &mut ModuleHeader,
    types: &FxHashMap<Word, Instruction>,
    function: &mut Function,
) {
    let defs: FxHashMap<Word, Instruction> = function
        .all_inst_iter()
        .filter_map(|inst| inst.result_id.map(|result_id| (result_id, inst.clone())))
        .collect();
    for block in &mut function.blocks {
        let instructions = &mut block.instructions;
        let mut cursor = 0;
        loop {
            if cursor >= instructions.len() {
                break;
            }
            // `process_instruction` may insert helper instructions before `cursor` and bump it,
            // so `cursor` still refers to the construct being replaced afterwards.
            if let Some(vectorized) =
                process_instruction(header, types, &defs, instructions, &mut cursor)
            {
                instructions[cursor] = vectorized;
            }
            cursor += 1;
        }
    }
}
/// Returns whether `inst` has the form `OpINotEqual %x %zero`, and whether `%x` can be traced
/// back to a boolean without looking at its integer value.
///
/// Requirements:
/// - `%zero` is a global `OpConstant` with 32-bit literal 0.
/// - `%x` is defined in `defs` either by `OpSelect %cond %one %zero`, where `%one`/`%zero` are
///   global `OpConstant`s with 32-bit literals 1 and 0, or by an `OpPhi` whose incoming values
///   all satisfy the same rule. Phi cycles are accepted.
fn can_fuse_bool(
    types: &FxHashMap<Word, Instruction>,
    defs: &FxHashMap<Word, (usize, Instruction)>,
    inst: &Instruction,
) -> bool {
    /// The 32-bit literal of `val`, if `val` names a global `OpConstant` holding one.
    fn literal_u32(types: &FxHashMap<Word, Instruction>, val: Word) -> Option<u32> {
        let def = types.get(&val)?;
        if def.class.opcode != Op::Constant {
            return None;
        }
        if let Operand::LiteralInt32(v) = def.operands[0] {
            Some(v)
        } else {
            None
        }
    }

    /// Whether `value` is a `select(_, 1, 0)` or a phi built only from such values.
    fn is_select_one_zero(
        types: &FxHashMap<Word, Instruction>,
        defs: &FxHashMap<Word, (usize, Instruction)>,
        seen: &mut FxHashSet<Word>,
        value: Word,
    ) -> bool {
        // A value already on the walk is part of a phi cycle. Accept it here; the other
        // incoming values decide the outcome.
        if !seen.insert(value) {
            return true;
        }
        let def = match defs.get(&value) {
            Some((_, def)) => def,
            None => return false,
        };
        match def.class.opcode {
            Op::Select => {
                literal_u32(types, def.operands[1].unwrap_id_ref()) == Some(1)
                    && literal_u32(types, def.operands[2].unwrap_id_ref()) == Some(0)
            }
            // Phi operands alternate (value, parent block); only the values matter here.
            Op::Phi => def.operands.iter().step_by(2).all(|incoming| {
                is_select_one_zero(types, defs, seen, incoming.unwrap_id_ref())
            }),
            _ => false,
        }
    }

    let is_ne_zero = inst.class.opcode == Op::INotEqual
        && literal_u32(types, inst.operands[1].unwrap_id_ref()) == Some(0);
    is_ne_zero
        && is_select_one_zero(
            types,
            defs,
            &mut FxHashSet::default(),
            inst.operands[0].unwrap_id_ref(),
        )
}
/// Returns an id of type `bool_ty` carrying the boolean that the integer `int_value` encodes.
/// New bool-typed `OpPhi`s are created as needed.
///
/// Relies on `can_fuse_bool` having accepted `int_value`:
/// - an `OpSelect` maps straight to its condition operand;
/// - an `OpPhi` is mirrored by a fresh `OpPhi` of type `bool_ty` whose incoming values are the
///   fused incoming values. The new phi is queued in `phis_to_insert` with the index of the
///   block holding the original phi.
///
/// `already_mapped` memoizes phi translations. The new id is recorded before recursing so that
/// phi cycles terminate.
///
/// # Panics
/// If `int_value` is missing from `defs` or is defined by anything other than
/// `OpSelect`/`OpPhi`.
fn fuse_bool(
    header: &mut ModuleHeader,
    defs: &FxHashMap<Word, (usize, Instruction)>,
    phis_to_insert: &mut Vec<(usize, Instruction)>,
    already_mapped: &mut FxHashMap<Word, Word>,
    bool_ty: Word,
    int_value: Word,
) -> Word {
    if let Some(&cached) = already_mapped.get(&int_value) {
        return cached;
    }
    let (home_block, def) = defs.get(&int_value).unwrap();
    if def.class.opcode == Op::Select {
        return def.operands[0].unwrap_id_ref();
    }
    if def.class.opcode != Op::Phi {
        bug!("can_fuse_bool should have prevented this case");
    }

    let fused_phi = id(header);
    already_mapped.insert(int_value, fused_phi);
    let mut fused_args = Vec::with_capacity(def.operands.len());
    // Phi operands come in (value, parent block) pairs.
    for pair in def.operands.chunks(2) {
        let (incoming, parent) = (&pair[0], &pair[1]);
        let fused_incoming = fuse_bool(
            header,
            defs,
            phis_to_insert,
            already_mapped,
            bool_ty,
            incoming.unwrap_id_ref(),
        );
        fused_args.push(Operand::IdRef(fused_incoming));
        fused_args.push(parent.clone());
    }
    phis_to_insert.push((
        *home_block,
        Instruction::new(Op::Phi, Some(bool_ty), Some(fused_phi), fused_args),
    ));
    fused_phi
}
/// Removes `OpINotEqual %x 0` tests where `%x` is an integer built by `select(cond, 1, 0)`,
/// possibly through phis. The underlying boolean is used directly instead (see
/// `can_fuse_bool` / `fuse_bool`).
///
/// Steps:
/// 1. Each matching `OpINotEqual` is overwritten with `OpNop`.
/// 2. Any bool phis it needs are inserted at the start of the block holding the corresponding
///    integer phi.
/// 3. Uses of its result are redirected via `super::apply_rewrite_rules`.
pub fn bool_fusion(
    header: &mut ModuleHeader,
    types: &FxHashMap<Word, Instruction>,
    function: &mut Function,
) {
    // Snapshot of every block-level definition, tagged with the index of its block.
    let mut defs: FxHashMap<Word, (usize, Instruction)> = FxHashMap::default();
    for (block_index, block) in function.blocks.iter().enumerate() {
        for inst in &block.instructions {
            if let Some(result_id) = inst.result_id {
                defs.insert(result_id, (block_index, inst.clone()));
            }
        }
    }

    let mut rewrite_rules: FxHashMap<Word, Word> = FxHashMap::default();
    let mut phis_to_insert: Vec<(usize, Instruction)> = Vec::new();
    let mut already_mapped: FxHashMap<Word, Word> = FxHashMap::default();
    let all_insts = function
        .blocks
        .iter_mut()
        .flat_map(|block| block.instructions.iter_mut());
    for inst in all_insts {
        if !can_fuse_bool(types, &defs, inst) {
            continue;
        }
        let fused = fuse_bool(
            header,
            &defs,
            &mut phis_to_insert,
            &mut already_mapped,
            inst.result_type.unwrap(),
            inst.operands[0].unwrap_id_ref(),
        );
        rewrite_rules.insert(inst.result_id.unwrap(), fused);
        *inst = Instruction::new(Op::Nop, None, None, vec![]);
    }

    for (block_index, phi) in phis_to_insert {
        function.blocks[block_index].instructions.insert(0, phi);
    }
    super::apply_rewrite_rules(&rewrite_rules, &mut function.blocks);
}