use std::collections::HashSet;
use waffle::cfg::CFGInfo;
use waffle::entity::EntityRef;
use waffle::pool::ListRef;
use waffle::{FunctionBody, Operator, Terminator, Type, Value, ValueDef};
/// Outcome of the taint analysis performed by `shadow_stack_escapes`.
enum EscapeAnalysisResult {
/// A tainted value reached a use the analysis does not follow (an operator
/// other than `i32.add`/`i32.sub`, a conditional branch or select input, a
/// return value, or a block argument on a backedge).
Escapes,
/// No such use was found; carries every value that was marked tainted
/// during the scan.
NonEscaping(HashSet<Value>),
}
fn shadow_stack_escapes(func: &FunctionBody, cfg: &CFGInfo) -> EscapeAnalysisResult {
let mut tainted = HashSet::new();
for (block_rpo, &block) in cfg.rpo.entries() {
for &inst in &func.blocks[block].insts {
match &func.values[inst] {
&ValueDef::Operator(Operator::GlobalGet { global_index }, _, _)
| &ValueDef::Operator(Operator::GlobalSet { global_index }, _, _)
if global_index.index() == 0 =>
{
log::trace!("tainted because global.get/set: {inst}");
tainted.insert(inst);
}
&ValueDef::Operator(Operator::I32Add, args, _)
| &ValueDef::Operator(Operator::I32Sub, args, _) => {
let args = &func.arg_pool[args];
if args.iter().any(|arg| tainted.contains(arg)) {
log::trace!("tainted because of arg: {inst}");
tainted.insert(inst);
}
}
&ValueDef::Operator(_, args, _) => {
let args = &func.arg_pool[args];
if args.iter().any(|arg| tainted.contains(arg)) {
log::trace!("shadow stack escape due to inst {inst}");
return EscapeAnalysisResult::Escapes;
}
}
&ValueDef::PickOutput(val, _, _) | &ValueDef::Alias(val)
if tainted.contains(&val) =>
{
log::trace!("taint on {val} propagates to {inst} because of alias or pick");
tainted.insert(inst);
}
_ => {}
}
}
match &func.blocks[block].terminator {
&Terminator::CondBr { cond, .. } | &Terminator::Select { value: cond, .. } => {
if tainted.contains(&cond) {
log::trace!("taint on input to conditional branch causes escape: {cond}");
return EscapeAnalysisResult::Escapes;
}
}
&Terminator::Return { ref values } => {
if values.iter().any(|v| tainted.contains(v)) {
log::trace!("taint on return value causes escape");
return EscapeAnalysisResult::Escapes;
}
}
_ => {}
}
let mut escaped_via_term = false;
func.blocks[block].terminator.visit_targets(|target| {
for (arg, (_, param)) in target
.args
.iter()
.zip(func.blocks[target.block].params.iter())
{
if tainted.contains(arg) {
let target_rpo = cfg.rpo_pos[target.block].unwrap();
if target_rpo.index() <= block_rpo.index() {
log::trace!(
"taint traveling on backedge from {} to {} ({} to {}) causes escape",
arg,
param,
block,
target.block
);
escaped_via_term = true;
}
tainted.insert(*param);
}
}
});
if escaped_via_term {
return EscapeAnalysisResult::Escapes;
}
}
EscapeAnalysisResult::NonEscaping(tainted)
}
/// Removes the shadow-stack computations from `func` when
/// `shadow_stack_escapes` reports that they never escape.
///
/// Every tainted value is dropped from the instruction list of each block, and
/// every block-call argument that referred to a tainted value is replaced by a
/// fresh `i32.const 0` appended to the entry block. Tainted block parameters
/// themselves are left in place.
///
/// The function is left completely untouched when the analysis reports an
/// escape, and also when the tainted set is empty: in that case there is
/// nothing to remove, so no placeholder constant is added.
///
/// # Panics
///
/// Panics if a tainted value used as a block-call argument does not have type
/// `i32`, or if a later instruction still uses the replaced value's type in a
/// way the assertion below rejects.
pub(crate) fn remove_shadow_stack_if_non_escaping(func: &mut FunctionBody, cfg: &CFGInfo) {
    let values_to_remove = match shadow_stack_escapes(func, cfg) {
        EscapeAnalysisResult::NonEscaping(values) => values,
        EscapeAnalysisResult::Escapes => return,
    };
    if values_to_remove.is_empty() {
        // Nothing was tainted, so there is nothing to strip; avoid leaving a
        // dead constant behind in the entry block.
        return;
    }

    log::trace!("removing shadow stack operations: {values_to_remove:?}");

    // Placeholder that stands in for any removed value still passed as a
    // block-call argument.
    let ty_u32 = func.type_pool.single(Type::I32);
    let const_zero = func.values.push(ValueDef::Operator(
        Operator::I32Const { value: 0 },
        ListRef::default(),
        ty_u32,
    ));
    func.blocks[func.entry].insts.push(const_zero);

    for block in func.blocks.values_mut() {
        block.insts.retain(|v| !values_to_remove.contains(v));
        block.terminator.update_targets(|target| {
            for arg in &mut target.args {
                if values_to_remove.contains(arg) {
                    // The placeholder is an i32, so the value it replaces
                    // must be one too.
                    assert_eq!(func.values[*arg].ty(&func.type_pool), Some(Type::I32));
                    *arg = const_zero;
                }
            }
        });
    }
}