use std::collections::HashMap;
use crate::{
ir::{function::SsaFunction, ops::SsaOp, value::ConstValue, variable::SsaVarId},
target::Target,
BitSet, PointerSize,
};
/// Memoizing constant evaluator that traces variable definitions inside an
/// [`SsaFunction`] and folds them with [`evaluate_const_op`].
pub struct ConstEvaluator<'a, T: Target> {
    /// Function whose definitions are traced.
    ssa: &'a SsaFunction<T>,
    /// Pointer width forwarded to the folding helpers.
    pointer_size: PointerSize,
    /// Recursion budget for a single query.
    max_depth: usize,
    /// Per-variable results; a `None` entry records a failed evaluation.
    cache: HashMap<SsaVarId, Option<ConstValue<T>>>,
    /// Variables currently on the evaluation stack, used to stop on cyclic definitions.
    visiting: BitSet,
}
impl<'a, T: Target> ConstEvaluator<'a, T> {
    /// Recursion budget used by [`ConstEvaluator::new`].
    const DEFAULT_MAX_DEPTH: usize = 20;

    /// Creates an evaluator for `ssa` with the default recursion budget.
    #[must_use]
    pub fn new(ssa: &'a SsaFunction<T>, ptr_size: PointerSize) -> Self {
        Self::with_max_depth(ssa, Self::DEFAULT_MAX_DEPTH, ptr_size)
    }

    /// Creates an evaluator for `ssa` that follows operand definitions at most
    /// `max_depth` levels deep before giving up on a query.
    #[must_use]
    pub fn with_max_depth(
        ssa: &'a SsaFunction<T>,
        max_depth: usize,
        ptr_size: PointerSize,
    ) -> Self {
        Self {
            ssa,
            cache: HashMap::new(),
            visiting: BitSet::new(ssa.variable_count().max(1)),
            max_depth,
            pointer_size: ptr_size,
        }
    }

    /// Records `value` as the constant for `var`, overriding its definition (and any
    /// previously cached result) until [`ConstEvaluator::clear_cache`] is called.
    pub fn set_known(&mut self, var: SsaVarId, value: ConstValue<T>) {
        self.cache.insert(var, Some(value));
    }

    /// Attempts to fold `var` to a constant by tracing its definition.
    pub fn evaluate_var(&mut self, var: SsaVarId) -> Option<ConstValue<T>> {
        self.evaluate_var_depth(var, 0).0
    }

    /// Evaluates `var` at recursion depth `depth`.
    ///
    /// Returns the folded value together with a flag that is `true` when the depth
    /// budget was exhausted somewhere in the traced sub-expression. A `None` produced
    /// under that condition depends on how deep `var` was reached, so it is not
    /// memoized; otherwise a later query starting closer to `var` would be answered
    /// from a stale failure instead of being evaluated with a full budget.
    fn evaluate_var_depth(
        &mut self,
        var: SsaVarId,
        depth: usize,
    ) -> (Option<ConstValue<T>>, bool) {
        // A cached (or injected) result is valid at any depth, so consult it first.
        if let Some(cached) = self.cache.get(&var) {
            return (cached.clone(), false);
        }
        if depth > self.max_depth {
            return (None, true);
        }

        let index = var.index();
        let tracked = index < self.visiting.len();
        if tracked {
            // Already on the evaluation stack: the definition depends on itself.
            if self.visiting.contains(index) {
                return (None, false);
            }
            self.visiting.insert(index);
        }

        // Copy the `&'a` reference out so the definition borrow is independent of `self`.
        let ssa = self.ssa;
        let (result, truncated) = match ssa.get_definition(var) {
            Some(op) => self.evaluate_op_depth(op, depth),
            None => (None, false),
        };

        if tracked {
            self.visiting.remove(index);
        }
        // A found constant is always correct; a failure is only reusable when it did
        // not come from running out of depth.
        if result.is_some() || !truncated {
            self.cache.insert(var, result.clone());
        }
        (result, truncated)
    }

    /// Attempts to fold `op` directly, resolving its operands through this evaluator.
    pub fn evaluate_op(&mut self, op: &SsaOp<T>) -> Option<ConstValue<T>> {
        self.evaluate_op_depth(op, 0).0
    }

    /// Folds `op` at recursion depth `depth`.
    ///
    /// The returned flag has the same meaning as in `evaluate_var_depth`: `true` when
    /// some operand lookup hit the depth limit.
    fn evaluate_op_depth(
        &mut self,
        op: &SsaOp<T>,
        depth: usize,
    ) -> (Option<ConstValue<T>>, bool) {
        if depth > self.max_depth {
            return (None, true);
        }
        let next = depth.saturating_add(1);

        // `evaluate_const_op` has no arm for plain copies, so trace them here.
        if let SsaOp::Copy { src, .. } = op {
            return self.evaluate_var_depth(*src, next);
        }

        // Read the pointer size before the closure borrows `self` mutably.
        let ptr_size = self.pointer_size;
        let mut truncated = false;
        let result = evaluate_const_op(
            op,
            |operand| {
                let (value, operand_truncated) = self.evaluate_var_depth(operand, next);
                truncated |= operand_truncated;
                value
            },
            ptr_size,
        );
        (result, truncated)
    }

    /// Consumes the evaluator and returns every variable that folded to a constant.
    #[must_use]
    pub fn into_results(self) -> HashMap<SsaVarId, ConstValue<T>> {
        self.cache
            .into_iter()
            .filter_map(|(var, opt)| opt.map(|val| (var, val)))
            .collect()
    }

    /// Returns the function this evaluator traces.
    #[must_use]
    pub fn ssa(&self) -> &SsaFunction<T> {
        self.ssa
    }

    /// Forgets all cached and injected results.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }
}
pub fn evaluate_const_op<T: Target>(
op: &SsaOp<T>,
mut get_const: impl FnMut(SsaVarId) -> Option<ConstValue<T>>,
ptr_size: PointerSize,
) -> Option<ConstValue<T>> {
match op {
SsaOp::Const { value, .. } => Some(value.clone()),
SsaOp::Add { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.add(&r, ptr_size)
}
SsaOp::Sub { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.sub(&r, ptr_size)
}
SsaOp::Mul { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.mul(&r, ptr_size)
}
SsaOp::Div { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.div(&r, ptr_size)
}
SsaOp::Rem { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.rem(&r, ptr_size)
}
SsaOp::Xor { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.bitwise_xor(&r, ptr_size)
}
SsaOp::And { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.bitwise_and(&r, ptr_size)
}
SsaOp::Or { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.bitwise_or(&r, ptr_size)
}
SsaOp::Shl { value, amount, .. } => {
let v = get_const(*value)?;
let a = get_const(*amount)?;
v.shl(&a, ptr_size)
}
SsaOp::Shr {
value,
amount,
unsigned,
..
} => {
let v = get_const(*value)?;
let a = get_const(*amount)?;
v.shr(&a, *unsigned, ptr_size)
}
SsaOp::Neg { operand, .. } => {
let v = get_const(*operand)?;
v.negate(ptr_size)
}
SsaOp::Not { operand, .. } => {
let v = get_const(*operand)?;
v.bitwise_not(ptr_size)
}
SsaOp::Ceq { left, right, .. } => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.ceq(&r)
}
SsaOp::Clt {
left,
right,
unsigned,
..
} => {
let l = get_const(*left)?;
let r = get_const(*right)?;
if *unsigned {
l.clt_un(&r)
} else {
l.clt(&r)
}
}
SsaOp::Cgt {
left,
right,
unsigned,
..
} => {
let l = get_const(*left)?;
let r = get_const(*right)?;
if *unsigned {
l.cgt_un(&r)
} else {
l.cgt(&r)
}
}
SsaOp::AddOvf {
left,
right,
unsigned,
..
} => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.add_checked(&r, *unsigned, ptr_size)
}
SsaOp::SubOvf {
left,
right,
unsigned,
..
} => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.sub_checked(&r, *unsigned, ptr_size)
}
SsaOp::MulOvf {
left,
right,
unsigned,
..
} => {
let l = get_const(*left)?;
let r = get_const(*right)?;
l.mul_checked(&r, *unsigned, ptr_size)
}
SsaOp::Conv {
operand,
target,
overflow_check,
unsigned,
..
} => {
let v = get_const(*operand)?;
let ptr_bytes = ptr_size.bytes() as u32;
if *overflow_check {
v.convert_to_checked(target, *unsigned, ptr_bytes)
} else {
v.convert_to(target, *unsigned, ptr_bytes)
}
}
SsaOp::Rol { value, amount, .. } => {
let v = get_const(*value)?;
let a = get_const(*amount)?;
let shift = a.as_i32()? as u32;
match v {
ConstValue::I8(v) => Some(ConstValue::I8(v.rotate_left(shift))),
ConstValue::I16(v) => Some(ConstValue::I16(v.rotate_left(shift))),
ConstValue::I32(v) => Some(ConstValue::I32(v.rotate_left(shift))),
ConstValue::I64(v) => Some(ConstValue::I64(v.rotate_left(shift))),
ConstValue::U8(v) => Some(ConstValue::U8(v.rotate_left(shift))),
ConstValue::U16(v) => Some(ConstValue::U16(v.rotate_left(shift))),
ConstValue::U32(v) => Some(ConstValue::U32(v.rotate_left(shift))),
ConstValue::U64(v) => Some(ConstValue::U64(v.rotate_left(shift))),
ConstValue::NativeInt(v) => Some(ConstValue::NativeInt(v.rotate_left(shift))),
ConstValue::NativeUInt(v) => Some(ConstValue::NativeUInt(v.rotate_left(shift))),
_ => None,
}
}
SsaOp::Ror { value, amount, .. } => {
let v = get_const(*value)?;
let a = get_const(*amount)?;
let shift = a.as_i32()? as u32;
match v {
ConstValue::I8(v) => Some(ConstValue::I8(v.rotate_right(shift))),
ConstValue::I16(v) => Some(ConstValue::I16(v.rotate_right(shift))),
ConstValue::I32(v) => Some(ConstValue::I32(v.rotate_right(shift))),
ConstValue::I64(v) => Some(ConstValue::I64(v.rotate_right(shift))),
ConstValue::U8(v) => Some(ConstValue::U8(v.rotate_right(shift))),
ConstValue::U16(v) => Some(ConstValue::U16(v.rotate_right(shift))),
ConstValue::U32(v) => Some(ConstValue::U32(v.rotate_right(shift))),
ConstValue::U64(v) => Some(ConstValue::U64(v.rotate_right(shift))),
ConstValue::NativeInt(v) => Some(ConstValue::NativeInt(v.rotate_right(shift))),
ConstValue::NativeUInt(v) => Some(ConstValue::NativeUInt(v.rotate_right(shift))),
_ => None,
}
}
SsaOp::Rcl { value, amount, .. } => {
let v = get_const(*value)?;
let a = get_const(*amount)?;
let shift = a.as_i32()? as u32;
match v {
ConstValue::I8(v) => Some(ConstValue::I8(v.rotate_left(shift))),
ConstValue::I16(v) => Some(ConstValue::I16(v.rotate_left(shift))),
ConstValue::I32(v) => Some(ConstValue::I32(v.rotate_left(shift))),
ConstValue::I64(v) => Some(ConstValue::I64(v.rotate_left(shift))),
ConstValue::U8(v) => Some(ConstValue::U8(v.rotate_left(shift))),
ConstValue::U16(v) => Some(ConstValue::U16(v.rotate_left(shift))),
ConstValue::U32(v) => Some(ConstValue::U32(v.rotate_left(shift))),
ConstValue::U64(v) => Some(ConstValue::U64(v.rotate_left(shift))),
ConstValue::NativeInt(v) => Some(ConstValue::NativeInt(v.rotate_left(shift))),
ConstValue::NativeUInt(v) => Some(ConstValue::NativeUInt(v.rotate_left(shift))),
_ => None,
}
}
SsaOp::Rcr { value, amount, .. } => {
let v = get_const(*value)?;
let a = get_const(*amount)?;
let shift = a.as_i32()? as u32;
match v {
ConstValue::I8(v) => Some(ConstValue::I8(v.rotate_right(shift))),
ConstValue::I16(v) => Some(ConstValue::I16(v.rotate_right(shift))),
ConstValue::I32(v) => Some(ConstValue::I32(v.rotate_right(shift))),
ConstValue::I64(v) => Some(ConstValue::I64(v.rotate_right(shift))),
ConstValue::U8(v) => Some(ConstValue::U8(v.rotate_right(shift))),
ConstValue::U16(v) => Some(ConstValue::U16(v.rotate_right(shift))),
ConstValue::U32(v) => Some(ConstValue::U32(v.rotate_right(shift))),
ConstValue::U64(v) => Some(ConstValue::U64(v.rotate_right(shift))),
ConstValue::NativeInt(v) => Some(ConstValue::NativeInt(v.rotate_right(shift))),
ConstValue::NativeUInt(v) => Some(ConstValue::NativeUInt(v.rotate_right(shift))),
_ => None,
}
}
SsaOp::BSwap { src, .. } => {
let v = get_const(*src)?;
match v {
ConstValue::I16(v) => Some(ConstValue::I16(v.swap_bytes())),
ConstValue::U16(v) => Some(ConstValue::U16(v.swap_bytes())),
ConstValue::I32(v) => Some(ConstValue::I32(v.swap_bytes())),
ConstValue::U32(v) => Some(ConstValue::U32(v.swap_bytes())),
ConstValue::I64(v) => Some(ConstValue::I64(v.swap_bytes())),
ConstValue::U64(v) => Some(ConstValue::U64(v.swap_bytes())),
_ => None,
}
}
SsaOp::BRev { src, .. } => {
let v = get_const(*src)?;
match v {
ConstValue::I8(v) => Some(ConstValue::I8(v.reverse_bits())),
ConstValue::U8(v) => Some(ConstValue::U8(v.reverse_bits())),
ConstValue::I16(v) => Some(ConstValue::I16(v.reverse_bits())),
ConstValue::U16(v) => Some(ConstValue::U16(v.reverse_bits())),
ConstValue::I32(v) => Some(ConstValue::I32(v.reverse_bits())),
ConstValue::U32(v) => Some(ConstValue::U32(v.reverse_bits())),
ConstValue::I64(v) => Some(ConstValue::I64(v.reverse_bits())),
ConstValue::U64(v) => Some(ConstValue::U64(v.reverse_bits())),
_ => None,
}
}
SsaOp::BitScanForward { src, .. } => {
let v = get_const(*src)?;
let bits = match v {
ConstValue::I8(v) => v.trailing_zeros(),
ConstValue::U8(v) => v.trailing_zeros(),
ConstValue::I16(v) => v.trailing_zeros(),
ConstValue::U16(v) => v.trailing_zeros(),
ConstValue::I32(v) => v.trailing_zeros(),
ConstValue::U32(v) => v.trailing_zeros(),
ConstValue::I64(v) => v.trailing_zeros(),
ConstValue::U64(v) => v.trailing_zeros(),
_ => return None,
};
Some(ConstValue::I32(bits as i32))
}
SsaOp::BitScanReverse { src, .. } => {
let v = get_const(*src)?;
let bits = match v {
ConstValue::I8(v) => 7u32.checked_sub(v.leading_zeros())?,
ConstValue::U8(v) => 7u32.checked_sub(v.leading_zeros())?,
ConstValue::I16(v) => 15u32.checked_sub(v.leading_zeros())?,
ConstValue::U16(v) => 15u32.checked_sub(v.leading_zeros())?,
ConstValue::I32(v) => 31u32.checked_sub(v.leading_zeros())?,
ConstValue::U32(v) => 31u32.checked_sub(v.leading_zeros())?,
ConstValue::I64(v) => 63u32.checked_sub(v.leading_zeros())?,
ConstValue::U64(v) => 63u32.checked_sub(v.leading_zeros())?,
_ => return None,
};
Some(ConstValue::U32(bits))
}
SsaOp::Popcount { src, .. } => {
let v = get_const(*src)?;
let count = match v {
ConstValue::I8(v) => v.count_ones(),
ConstValue::U8(v) => v.count_ones(),
ConstValue::I16(v) => v.count_ones(),
ConstValue::U16(v) => v.count_ones(),
ConstValue::I32(v) => v.count_ones(),
ConstValue::U32(v) => v.count_ones(),
ConstValue::I64(v) => v.count_ones(),
ConstValue::U64(v) => v.count_ones(),
_ => return None,
};
Some(ConstValue::I32(count as i32))
}
SsaOp::Parity { src, .. } => {
let v = get_const(*src)?;
let parity = match v {
ConstValue::I8(v) => v.count_ones() % 2,
ConstValue::U8(v) => v.count_ones() % 2,
ConstValue::I16(v) => v.count_ones() % 2,
ConstValue::U16(v) => v.count_ones() % 2,
ConstValue::I32(v) => v.count_ones() % 2,
ConstValue::U32(v) => v.count_ones() % 2,
ConstValue::I64(v) => v.count_ones() % 2,
ConstValue::U64(v) => v.count_ones() % 2,
_ => return None,
};
Some(ConstValue::I32(parity as i32))
}
SsaOp::Select {
condition,
true_val,
false_val,
..
} => {
let cond = get_const(*condition)?;
match cond.as_i64() {
Some(0) => get_const(*false_val),
Some(_) => get_const(*true_val),
None => None,
}
}
_ => None,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ir::{block::SsaBlock, instruction::SsaInstruction, variable::DefSite, VariableOrigin},
        testing,
        testing::{MockTarget, MockType},
    };

    type Cv = ConstValue<MockTarget>;
    type Op = SsaOp<MockTarget>;

    /// Builds an SSA variable id from its raw index.
    fn var(index: usize) -> SsaVarId {
        SsaVarId::from_index(index)
    }

    /// Operand table for the `evaluate_const_op` tests: ids 0..=3 map to fixed `I32`
    /// constants, every other id is unknown.
    fn resolve(id: SsaVarId) -> Option<Cv> {
        let raw = match id.index() {
            0 => 10,
            1 => 3,
            2 => -1,
            3 => i32::MAX,
            _ => return None,
        };
        Some(Cv::I32(raw))
    }

    /// Folds `op` against `resolve` on a 64-bit target.
    fn fold(op: &Op) -> Option<Cv> {
        evaluate_const_op(op, resolve, PointerSize::Bit64)
    }

    /// Asserts that each `(op, expected)` pair folds to `expected`.
    fn check_all(cases: Vec<(Op, Option<Cv>)>) {
        for (position, (op, expected)) in cases.into_iter().enumerate() {
            assert_eq!(fold(&op), expected, "case #{}", position);
        }
    }

    #[test]
    fn const_evaluator_resolves_consts_and_injected_values() {
        let ssa = testing::const_i32_return(42);
        let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64);
        assert_eq!(evaluator.ssa().block_count(), 1);

        let v0 = var(0);
        // Value derived from the definition.
        assert_eq!(evaluator.evaluate_var(v0), Some(Cv::I32(42)));
        // An injected value shadows the definition...
        evaluator.set_known(v0, Cv::I32(7));
        assert_eq!(evaluator.evaluate_var(v0), Some(Cv::I32(7)));
        // ...until the cache is dropped.
        evaluator.clear_cache();
        assert_eq!(evaluator.evaluate_var(v0), Some(Cv::I32(42)));

        let results = evaluator.into_results();
        assert_eq!(results.get(&v0), Some(&Cv::I32(42)));
    }

    #[test]
    fn const_evaluator_traces_copy_chains_and_honors_depth_limit() {
        let mut ssa = crate::ir::SsaFunction::<MockTarget>::new(0, 0);
        for i in 0..3 {
            ssa.create_variable(
                VariableOrigin::Local(i),
                0,
                DefSite::instruction(0, usize::from(i)),
                MockType::I32,
            );
        }

        // v0 = 5; v1 = v0; v2 = v1
        let ops: Vec<Op> = vec![
            SsaOp::Const {
                dest: var(0),
                value: Cv::I32(5),
            },
            SsaOp::Copy {
                dest: var(1),
                src: var(0),
            },
            SsaOp::Copy {
                dest: var(2),
                src: var(1),
            },
        ];
        let mut block = SsaBlock::new(0);
        for op in ops {
            block.add_instruction(SsaInstruction::synthetic(op));
        }
        ssa.add_block(block);

        let mut deep = ConstEvaluator::new(&ssa, PointerSize::Bit64);
        assert_eq!(deep.evaluate_var(var(2)), Some(Cv::I32(5)));

        let mut shallow = ConstEvaluator::with_max_depth(&ssa, 0, PointerSize::Bit64);
        assert_eq!(shallow.evaluate_var(var(2)), None);
    }

    #[test]
    fn evaluate_const_op_folds_arithmetic_bitwise_and_shifts() {
        let (ten, three, minus_one) = (var(0), var(1), var(2));
        let dest = var(9);
        check_all(vec![
            (
                SsaOp::Add { dest, left: ten, right: three, flags: None },
                Some(Cv::I32(13)),
            ),
            (
                SsaOp::Sub { dest, left: ten, right: three, flags: None },
                Some(Cv::I32(7)),
            ),
            (
                SsaOp::Mul { dest, left: ten, right: three, flags: None },
                Some(Cv::I32(30)),
            ),
            (
                SsaOp::Div {
                    dest,
                    left: ten,
                    right: three,
                    unsigned: false,
                    flags: None,
                },
                Some(Cv::I32(3)),
            ),
            (
                SsaOp::Rem {
                    dest,
                    left: ten,
                    right: three,
                    unsigned: false,
                    flags: None,
                },
                Some(Cv::I32(1)),
            ),
            (
                SsaOp::And { dest, left: ten, right: three, flags: None },
                Some(Cv::I32(2)),
            ),
            (
                SsaOp::Or { dest, left: ten, right: three, flags: None },
                Some(Cv::I32(11)),
            ),
            (
                SsaOp::Xor { dest, left: ten, right: three, flags: None },
                Some(Cv::I32(9)),
            ),
            (
                SsaOp::Shl { dest, value: three, amount: three, flags: None },
                Some(Cv::I32(24)),
            ),
            (
                SsaOp::Shr {
                    dest,
                    value: minus_one,
                    amount: three,
                    unsigned: true,
                    flags: None,
                },
                Some(Cv::I32(536_870_911)),
            ),
        ]);
    }

    #[test]
    fn evaluate_const_op_folds_unary_comparison_and_checked_arithmetic() {
        let (ten, three, max) = (var(0), var(1), var(3));
        let dest = var(9);
        check_all(vec![
            (
                SsaOp::Neg { dest, operand: three, flags: None },
                Some(Cv::I32(-3)),
            ),
            (
                SsaOp::Not { dest, operand: three, flags: None },
                Some(Cv::I32(!3)),
            ),
            (
                SsaOp::Ceq { dest, left: ten, right: three },
                Some(Cv::False),
            ),
            (
                SsaOp::Clt { dest, left: three, right: ten, unsigned: false },
                Some(Cv::True),
            ),
            (
                SsaOp::Cgt { dest, left: ten, right: three, unsigned: true },
                Some(Cv::True),
            ),
            (
                SsaOp::AddOvf {
                    dest,
                    left: max,
                    right: three,
                    unsigned: false,
                    flags: None,
                },
                None,
            ),
            (
                SsaOp::SubOvf {
                    dest,
                    left: ten,
                    right: three,
                    unsigned: false,
                    flags: None,
                },
                Some(Cv::I32(7)),
            ),
            (
                SsaOp::MulOvf {
                    dest,
                    left: ten,
                    right: three,
                    unsigned: false,
                    flags: None,
                },
                Some(Cv::I32(30)),
            ),
        ]);
    }

    #[test]
    fn evaluate_const_op_returns_none_for_unknown_inputs_and_unsupported_ops() {
        let dest = var(9);
        check_all(vec![
            (
                SsaOp::Add { dest, left: var(0), right: var(99), flags: None },
                None,
            ),
            (
                SsaOp::Conv {
                    dest,
                    operand: var(0),
                    target: MockType::I64,
                    overflow_check: false,
                    unsigned: false,
                },
                None,
            ),
            (SsaOp::Return { value: Some(var(0)) }, None),
        ]);
    }
}