use crate::ir::prelude::*;
use crate::opt::prelude::*;
use crate::{ir::InstData, ty::*, value::IntValue};
use std::cmp::min;
/// Constant folding pass.
///
/// A field-less marker type; its `Pass` implementation forwards every
/// instruction to the free function `run_on_inst` in this module.
pub struct ConstFolding;
// Hooks constant folding into the pass infrastructure at instruction level.
impl Pass for ConstFolding {
/// Fold a single instruction by delegating to the module-level
/// `run_on_inst`. The pass context is not used.
///
/// Returns whatever the module-level `run_on_inst` returns, i.e. `true` if
/// the instruction's result was replaced.
fn run_on_inst(_ctx: &PassContext, inst: Inst, unit: &mut UnitBuilder) -> bool {
// A bare path call resolves to the free function, not to this method.
run_on_inst(unit, inst)
}
}
/// Attempt to constant-fold a single instruction.
///
/// The builder is first positioned via `unit.insert_before(inst)`. If the
/// instruction yields a result, it is dispatched by opcode to one of the
/// folding helpers in this module. When a helper produces a replacement value,
/// the name of the old result (if any) is moved over to the replacement and
/// all uses of the old result are redirected to it.
///
/// Returns `true` if the result of `inst` was replaced, `false` otherwise.
///
/// # Panics
///
/// Panics if a helper returns a replacement whose type differs from the type
/// of the original result.
pub fn run_on_inst(unit: &mut UnitBuilder, inst: Inst) -> bool {
    unit.insert_before(inst);

    // Without a result there is no value whose uses could be replaced.
    if !unit.has_result(inst) {
        return false;
    }

    let result = unit.inst_result(inst);
    let result_ty = unit.value_type(result);
    let data = &unit[inst];

    // Opcodes with a dedicated helper are matched explicitly; all remaining
    // instructions are folded according to the shape of their data.
    let folded = match data.opcode() {
        Opcode::Shl | Opcode::Shr => fold_shift(unit, inst, &result_ty),
        Opcode::Mux => fold_mux(unit, inst),
        Opcode::ExtField => fold_ext_field(unit, inst),
        Opcode::ExtSlice => fold_ext_slice(unit, inst),
        Opcode::InsSlice => fold_ins_slice(unit, inst),
        _ => match *data {
            InstData::Unary { opcode, args, .. } => {
                fold_unary(unit, opcode, result_ty.clone(), args[0])
            }
            InstData::Binary { opcode, args, .. } => {
                fold_binary(unit, opcode, result_ty.clone(), args)
            }
            _ => None,
        },
    };

    let folded = match folded {
        Some(v) => v,
        None => return false,
    };

    // Folding must never change the type of the value being replaced.
    let folded_ty = unit.value_type(folded);
    assert_eq!(
        result_ty,
        folded_ty,
        "types before (lhs) and after (rhs) folding must match (before: {}, after: {})",
        inst.dump(&unit),
        unit.get_value_inst(folded)
            .map(|v| v.dump(&unit).to_string())
            .unwrap_or_else(|| folded.dump(&unit).to_string())
    );

    // Carry the name of the old result over to its replacement.
    let old_name = unit.get_name(result).map(String::from);
    if let Some(name) = old_name {
        unit.set_name(folded, name);
        unit.clear_name(result);
    }

    unit.replace_use(result, folded);
    true
}
/// Attempt to constant-fold the instruction that produces `value`.
///
/// Values for which `get_value_inst` returns `None` are left untouched and
/// `false` is returned. Otherwise the result of `run_on_inst` on the
/// producing instruction is returned.
pub fn run_on_value(unit: &mut UnitBuilder, value: Value) -> bool {
    match unit.get_value_inst(value) {
        Some(inst) => run_on_inst(unit, inst),
        None => false,
    }
}
/// Fold a unary instruction whose result has type `ty`.
///
/// Only integer results are handled; any other type yields `None`.
fn fold_unary(unit: &mut UnitBuilder, op: Opcode, ty: Type, arg: Value) -> Option<Value> {
    if !ty.is_int() {
        return None;
    }
    fold_unary_int(unit, op, arg)
}
/// Fold a unary integer instruction with a constant operand.
///
/// Returns `None` if `arg` is not a constant integer, or if
/// `IntValue::try_unary_op` does not produce a result for `op`. Otherwise a
/// new integer constant holding the computed value is built and returned.
fn fold_unary_int(unit: &mut UnitBuilder, op: Opcode, arg: Value) -> Option<Value> {
    let folded = unit
        .get_const_int(arg)
        .and_then(|imm| IntValue::try_unary_op(op, imm))?;
    Some(unit.ins().const_int(folded))
}
/// Fold a binary instruction whose result has type `ty`.
///
/// Only integer results are handled; any other type yields `None`. The width
/// of the integer result type is forwarded to `fold_binary_int`.
fn fold_binary(unit: &mut UnitBuilder, op: Opcode, ty: Type, args: [Value; 2]) -> Option<Value> {
    if !ty.is_int() {
        return None;
    }
    let width = ty.unwrap_int();
    fold_binary_int(unit, op, width, args)
}
/// Fold a binary integer instruction.
///
/// `width` is the bit width of the instruction's integer result and is used
/// when a fresh all-zeros or all-ones constant has to be built. `args` are the
/// two operands in instruction order.
///
/// Three situations are handled, in this order:
///
/// 1. Exactly one operand is constant and the operation is commutative
///    (`and`, `or`, `xor`, `add`, `smul`, `umul`): identities such as
///    `x & 0 = 0`, `x | 0 = x`, `x * 1 = x`, `x | ~0 = ~0`, `x & ~0 = x` and
///    `x ^ ~0 = not x` apply regardless of which side the constant is on.
/// 2. Only the right-hand operand is constant and the operation is not
///    commutative: `x - 0 = x`, `x / 1 = x`, and `x mod 1 = x rem 1 = 0`.
/// 3. Both operands are constant: the result is computed through
///    `IntValue::try_binary_op`, falling back to `IntValue::try_compare_op`.
///
/// Returns the replacement value, or `None` if nothing could be folded.
fn fold_binary_int(
    unit: &mut UnitBuilder,
    op: Opcode,
    width: usize,
    args: [Value; 2],
) -> Option<Value> {
    let imm0 = unit.get_const_int(args[0]);
    let imm1 = unit.get_const_int(args[1]);

    // Commutative operations: pick the constant from whichever side has it.
    // If both or neither operand is constant, this stage is skipped.
    let (arg_kon, arg_var) = match (imm0, imm1) {
        (None, Some(_)) => (imm1, args[0]),
        (Some(_), None) => (imm0, args[1]),
        _ => (None, args[0]),
    };
    if let Some(a) = arg_kon {
        match op {
            Opcode::And | Opcode::Smul | Opcode::Umul if a.is_zero() => {
                return Some(unit.ins().const_int(IntValue::zero(width)))
            }
            // `Sub` is deliberately absent here: `0 - x` is not `x`. The
            // `x - 0` identity is handled with the other non-commutative
            // operations below, where the constant must be on the right.
            Opcode::Or | Opcode::Xor | Opcode::Add if a.is_zero() => return Some(arg_var),
            Opcode::Smul | Opcode::Umul if a.is_one() => return Some(arg_var),
            Opcode::Or if a.is_all_ones() => {
                return Some(unit.ins().const_int(IntValue::all_ones(width)))
            }
            Opcode::And if a.is_all_ones() => return Some(arg_var),
            Opcode::Xor if a.is_all_ones() => return Some(unit.ins().not(arg_var)),
            _ => (),
        }
    }

    // Non-commutative operations: the identities only hold when the constant
    // is the right-hand operand.
    let (arg_kon, arg_var) = match (imm0, imm1) {
        (None, Some(_)) => (imm1, args[0]),
        _ => (None, args[0]),
    };
    if let Some(a) = arg_kon {
        match op {
            Opcode::Sub if a.is_zero() => return Some(arg_var),
            Opcode::Sdiv | Opcode::Udiv if a.is_one() => return Some(arg_var),
            Opcode::Smod | Opcode::Umod | Opcode::Srem | Opcode::Urem if a.is_one() => {
                return Some(unit.ins().const_int(IntValue::zero(width)))
            }
            _ => (),
        }
    }

    // Both operands must be constant from here on; otherwise give up.
    let (imm0, imm1) = (imm0?, imm1?);
    let result = None
        .or_else(|| IntValue::try_binary_op(op, imm0, imm1))
        .or_else(|| IntValue::try_compare_op(op, imm0, imm1))?;
    Some(unit.ins().const_int(result))
}
/// Fold a shift instruction (`shl` or `shr`) with a constant shift amount.
///
/// The instruction's operands are, in order, the base value, the hidden value
/// that supplies the bits shifted in, and the shift amount. `ty` is the type
/// of the instruction's result and is used to build the zero constant that
/// the pieces are inserted into.
///
/// * A constant amount of zero folds to the base value itself.
/// * Signal and pointer bases are not folded any further.
/// * Otherwise, a constant amount (clamped to the length of the hidden value)
///   is turned into slice extractions and insertions, which are themselves
///   run through the folder.
///
/// Returns `None` if the shift amount is not constant.
fn fold_shift(unit: &mut UnitBuilder, inst: Inst, ty: &Type) -> Option<Value> {
    let base = unit[inst].args()[0];
    let hidden = unit[inst].args()[1];
    let amount = unit[inst].args()[2];
    let const_amount = unit.get_const_int(amount);
    let is_shl = unit[inst].opcode() == Opcode::Shl;

    // Shifting by zero leaves the base untouched.
    if const_amount.map(IntValue::is_zero).unwrap_or(false) {
        return Some(base);
    }

    // Signals and pointers are not decomposed into slices.
    let base_ty = unit.value_type(base);
    if base_ty.is_signal() || base_ty.is_pointer() {
        return None;
    }

    // Everything below needs a constant shift amount.
    let requested = const_amount?.to_usize();
    let base_width = base_ty.len();
    let hidden_width = unit.value_type(hidden).len();
    let shift = min(requested, hidden_width);
    trace!(
        "Fold const shift `{}` (amount: {}, base_width: {}, hidden_width: {})",
        inst.dump(&unit),
        shift,
        base_width,
        hidden_width
    );

    // The base is shifted out completely; the result is a slice of the hidden
    // value alone.
    if shift >= base_width {
        let offset = if is_shl {
            hidden_width - shift
        } else {
            shift - base_width
        };
        trace!(" Base fully shifted out; hidden offset {}", offset);
        let r = unit.ins().ext_slice(hidden, offset, base_width);
        return Some(fold_ext_slice(unit, unit.value_inst(r)).unwrap_or(r));
    }

    // The result mixes `kept` elements of the base with `shift` elements of
    // the hidden value; assemble both pieces into a zero constant.
    let kept = base_width - shift;
    let (b, h, z0, z1, z2) = if is_shl {
        let b = unit.ins().ext_slice(base, 0, kept);
        let h = unit.ins().ext_slice(hidden, hidden_width - shift, shift);
        let z0 = unit.ins().const_zero(ty);
        let z1 = unit.ins().ins_slice(z0, b, shift, kept);
        let z2 = unit.ins().ins_slice(z1, h, 0, shift);
        (b, h, z0, z1, z2)
    } else {
        let h = unit.ins().ext_slice(hidden, 0, shift);
        let b = unit.ins().ext_slice(base, shift, kept);
        let z0 = unit.ins().const_zero(ty);
        let z1 = unit.ins().ins_slice(z0, h, kept, shift);
        let z2 = unit.ins().ins_slice(z1, b, 0, kept);
        (b, h, z0, z1, z2)
    };

    // Give the freshly built helper instructions a chance to fold as well.
    run_on_value(unit, h);
    run_on_value(unit, b);
    run_on_value(unit, z0);
    run_on_value(unit, z1);
    Some(fold_ins_slice(unit, unit.value_inst(z2)).unwrap_or(z2))
}
/// Fold a slice insertion.
///
/// The instruction's operands are the target being inserted into and the
/// value being inserted; its immediates are the slice offset and length.
///
/// * For integer or array targets, a zero-length insertion folds to the
///   target, and an insertion covering the target's full width/length folds
///   to the inserted value.
/// * If both target and value are constant integers, the insertion is
///   computed and a new constant is built.
///
/// Returns `None` in all other cases.
fn fold_ins_slice(unit: &mut UnitBuilder, inst: Inst) -> Option<Value> {
    let (target, value, offset, len) = {
        let data = &unit[inst];
        (
            data.args()[0],
            data.args()[1],
            data.imms()[0],
            data.imms()[1],
        )
    };

    // Trivial cases: nothing is overwritten, or everything is.
    match unit.value_type(target).as_ref() {
        IntType(_) | ArrayType(..) if len == 0 => return Some(target),
        IntType(w) | ArrayType(w, _) if len == *w => return Some(value),
        _ => (),
    }

    // Both operands constant: perform the insertion at compile time.
    let merged = match (unit.get_const_int(target), unit.get_const_int(value)) {
        (Some(dst), Some(src)) => {
            let mut merged = dst.clone();
            merged.insert_slice(offset, len, src);
            merged
        }
        _ => return None,
    };
    Some(unit.ins().const_int(merged))
}
/// Fold a slice extraction.
///
/// The instruction's single operand is the target being sliced; its
/// immediates are the slice offset and length.
///
/// * For integer or array targets, a zero-length extraction folds to a zero
///   constant of the instruction's type, and an extraction covering the
///   target's full width/length folds to the target itself.
/// * If the target is a constant integer, the slice is computed and a new
///   constant is built.
///
/// Returns `None` in all other cases.
fn fold_ext_slice(unit: &mut UnitBuilder, inst: Inst) -> Option<Value> {
    let result_ty = unit.inst_type(inst);
    let (target, offset, len) = {
        let data = &unit[inst];
        (data.args()[0], data.imms()[0], data.imms()[1])
    };

    // Trivial cases: nothing is extracted, or everything is.
    match unit.value_type(target).as_ref() {
        IntType(..) | ArrayType(..) if len == 0 => {
            return Some(unit.ins().const_zero(&result_ty))
        }
        IntType(w) | ArrayType(w, _) if len == *w => return Some(target),
        _ => (),
    }

    // Constant target: perform the extraction at compile time.
    let slice = unit.get_const_int(target)?.extract_slice(offset, len);
    Some(unit.ins().const_int(slice))
}
/// Fold a field extraction whose target is produced by an aggregate
/// constructor.
///
/// * If the target is produced by `ArrayUniform`, the constructor's first
///   operand is returned.
/// * If the target is produced by `Array` or `Struct`, the constructor operand
///   at the extracted offset is returned, provided that offset is within the
///   constructor's operand list.
///
/// Returns `None` if the target is not produced by an instruction, or in any
/// other case.
fn fold_ext_field(unit: &mut UnitBuilder, inst: Inst) -> Option<Value> {
    let target = unit[inst].args()[0];
    let producer = unit.get_value_inst(target)?;
    let source = &unit[producer];
    let index = unit[inst].imms()[0];
    match source.opcode() {
        Opcode::ArrayUniform => Some(source.args()[0]),
        // An out-of-range offset yields `None` rather than a panic.
        Opcode::Array | Opcode::Struct => source.args().get(index).copied(),
        _ => None,
    }
}
/// Fold a multiplexer with a constant selector.
///
/// The instruction's operands are the aggregate of choices and the selector.
/// If the selector is a constant integer, the mux is replaced by a field
/// extraction of the selected element from the choices. Returns `None` if the
/// selector is not constant.
fn fold_mux(unit: &mut UnitBuilder, inst: Inst) -> Option<Value> {
    let (choices, sel) = {
        let args = unit[inst].args();
        (args[0], args[1])
    };
    let index = unit.get_const_int(sel)?.to_usize();
    let picked = unit.ins().ext_field(choices, index);
    Some(picked)
}